[FIX] purchase: cr -> self.cr in report
[odoo/odoo.git] / bin / netsvc.py
index 8624f56..475bbc3 100644 (file)
@@ -1,12 +1,12 @@
-#!/usr/bin/python
-# -*- encoding: utf-8 -*-
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 ##############################################################################
 #
 #    OpenERP, Open Source Management Solution
-#    Copyright (C) 2004-2008 Tiny SPRL (<http://tiny.be>). All Rights Reserved
+#    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>). All Rights Reserved
 #    The refactoring about the OpenSSL support come from Tryton
-#    Copyright (C) 2007-2008 Cédric Krier.
-#    Copyright (C) 2007-2008 Bertrand Chenal.
+#    Copyright (C) 2007-2009 Cédric Krier.
+#    Copyright (C) 2007-2009 Bertrand Chenal.
 #    Copyright (C) 2008 B2CK SPRL.
 #
 #    This program is free software: you can redistribute it and/or modify
 #
 ##############################################################################
 
-
-import SimpleXMLRPCServer
-import SocketServer
 import logging
 import logging.handlers
-import os
-import signal
-import socket
 import sys
 import threading
 import time
-import xmlrpclib
 import release
-
-_service = {}
-_group = {}
-_res_id = 1
-_res = {}
-
-class ServiceEndPointCall(object):
-    def __init__(self, id, method):
-        self._id = id
-        self._meth = method
-
-    def __call__(self, *args):
-        _res[self._id] = self._meth(*args)
-        return self._id
-
-
-class ServiceEndPoint(object):
-    def __init__(self, name, id):
-        self._id = id
-        self._meth = {}
-        s = _service[name]
-        for m in s._method:
-            self._meth[m] = s._method[m]
-
-    def __getattr__(self, name):
-        return ServiceEndPointCall(self._id, self._meth[name])
-
+from pprint import pformat
+import warnings
 
 class Service(object):
-    _serviceEndPointID = 0
+    """ Base class for *Local* services
 
+        Functionality here is trusted, no authentication.
+    """
+    _services = {}
     def __init__(self, name, audience=''):
-        _service[name] = self
+        Service._services[name] = self
         self.__name = name
-        self._method = {}
-        self.exportedMethods = None
-        self._response_process = None
-        self._response_process_id = None
-        self._response = None
+        self._methods = {}
 
     def joinGroup(self, name):
-        if not name in _group:
-            _group[name] = {}
-        _group[name][self.__name] = self
-
-    def exportMethod(self, m):
-        if callable(m):
-            self._method[m.__name__] = m
-
-    def serviceEndPoint(self, s):
-        if Service._serviceEndPointID >= 2**16:
-            Service._serviceEndPointID = 0
-        Service._serviceEndPointID += 1
-        return ServiceEndPoint(s, self._serviceEndPointID)
+        raise Exception("No group for local services")
+        #GROUPS.setdefault(name, {})[self.__name] = self
 
-    def conversationId(self):
-        return 1
+    @classmethod
+    def exists(cls, name):
+        return name in cls._services
 
-    def processResponse(self, s, id):
-        self._response_process, self._response_process_id = s, id
+    @classmethod
+    def remove(cls, name):
+        if cls.exists(name):
+            cls._services.pop(name)
 
-    def processFailure(self, s, id):
-        pass
-
-    def resumeResponse(self, s):
-        pass
-
-    def cancelResponse(self, s):
-        pass
-
-    def suspendResponse(self, s):
-        if self._response_process:
-            self._response_process(self._response_process_id,
-                                   _res[self._response_process_id])
-        self._response_process = None
-        self._response = s(self._response_process_id)
+    def exportMethod(self, method):
+        if callable(method):
+            self._methods[method.__name__] = method
 
     def abortResponse(self, error, description, origin, details):
-        import tools
         if not tools.config['debug_mode']:
             raise Exception("%s -- %s\n\n%s"%(origin, description, details))
         else:
             raise
 
-    def currentFailure(self, s):
-        pass
-
-
-class LocalService(Service):
+class LocalService(object):
+    """ Proxy for local services. 
+    
+        Any instance of this class will behave like the single instance
+        of Service(name)
+    """
+    __logger = logging.getLogger('service')
     def __init__(self, name):
         self.__name = name
         try:
-            s = _service[name]
-            self._service = s
-            for m in s._method:
-                setattr(self, m, s._method[m])
+            self._service = Service._services[name]
+            for method_name, method_definition in self._service._methods.items():
+                setattr(self, method_name, method_definition)
         except KeyError, keyError:
-            Logger().notifyChannel('module', LOG_ERROR, 'This service does not exists: %s' % (str(keyError),) )
+            self.__logger.error('This service does not exist: %s' % (str(keyError),) )
             raise
 
-def service_exist(name):
-    return (name in _service) and bool(_service[name])
+    def __call__(self, method, *params):
+        return getattr(self, method)(*params)
+
+class ExportService(object):
+    """ Proxy for exported services. 
+
+    All methods here should take an AuthProxy as their first parameter. It
+    will be appended by the calling framework.
+
+    Note that this class has no direct proxy, capable of calling 
+    eservice.method(). Rather, the proxy should call 
+    dispatch(method,auth,params)
+    """
+    
+    _services = {}
+    _groups = {}
+    
+    def __init__(self, name, audience=''):
+        ExportService._services[name] = self
+        self.__name = name
+
+    def joinGroup(self, name):
+        ExportService._groups.setdefault(name, {})[self.__name] = self
+
+    @classmethod
+    def getService(cls,name):
+        return cls._services[name]
+
+    def dispatch(self, method, auth, params):
+        raise Exception("stub dispatch at %s" % self.__name)
+        
+    def new_dispatch(self,method,auth,params):
+        raise Exception("stub dispatch at %s" % self.__name)
+
+    def abortResponse(self, error, description, origin, details):
+        if not tools.config['debug_mode']:
+            raise Exception("%s -- %s\n\n%s"%(origin, description, details))
+        else:
+            raise
 
 LOG_NOTSET = 'notset'
 LOG_DEBUG_RPC = 'debug_rpc'
 LOG_DEBUG = 'debug'
+LOG_TEST = 'test'
 LOG_INFO = 'info'
 LOG_WARNING = 'warn'
 LOG_ERROR = 'error'
 LOG_CRITICAL = 'critical'
 
-# add new log level below DEBUG
-logging.DEBUG_RPC = logging.DEBUG - 1
+logging.DEBUG_RPC = logging.DEBUG - 2
+logging.addLevelName(logging.DEBUG_RPC, 'DEBUG_RPC')
+
+logging.TEST = logging.INFO - 5
+logging.addLevelName(logging.TEST, 'TEST')
 
 def init_logger():
-    from tools import config
     import os
+    from tools.translate import resetlocale
+    resetlocale()
 
     logger = logging.getLogger()
+    # create a format for log messages and dates
+    formatter = logging.Formatter('[%(asctime)s] %(levelname)s:%(name)s:%(message)s')
 
-    if config['syslog']:
+    if tools.config['syslog']:
         # SysLog Handler
         if os.name == 'nt':
-            sysloghandler = logging.handlers.NTEventLogHandler("%s %s" %
-                                                         (release.description,
-                                                          release.version))
+            handler = logging.handlers.NTEventLogHandler("%s %s" % (release.description, release.version))
         else:
-            sysloghandler = logging.handlers.SysLogHandler('/dev/log')
-        formatter = logging.Formatter('%(application)s:%(uncoloredlevelname)s:%(name)s:%(message)s')
-        sysloghandler.setFormatter(formatter)
-        logger.addHandler(sysloghandler)
+            handler = logging.handlers.SysLogHandler('/dev/log')
+        formatter = logging.Formatter("%s %s" % (release.description, release.version) + ':%(levelname)s:%(name)s:%(message)s')
 
-    # create a format for log messages and dates
-    formatter = logging.Formatter('[%(asctime)s] %(levelname)s:%(name)s:%(message)s', '%a %b %d %Y %H:%M:%S')
-    if config['logfile']:
+    elif tools.config['logfile']:
         # LogFile Handler
-        logf = config['logfile']
+        logf = tools.config['logfile']
         try:
             dirname = os.path.dirname(logf)
             if dirname and not os.path.isdir(dirname):
                 os.makedirs(dirname)
-            handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
+            if tools.config['logrotate'] is not False:
+                handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
+            elif os.name == 'posix':
+                handler = logging.handlers.WatchedFileHandler(logf)
+            else:
+                handler = logging.handlers.FileHandler(logf)
         except Exception, ex:
-            sys.stderr.write("ERROR: couldn't create the logfile directory\n")
+            sys.stderr.write("ERROR: couldn't create the logfile directory. Logging to the standard output.\n")
             handler = logging.StreamHandler(sys.stdout)
     else:
         # Normal Handler on standard output
@@ -192,9 +182,9 @@ def init_logger():
 
     # add the handler to the root logger
     logger.addHandler(handler)
-    logger.setLevel(config['log_level'] or '0')
+    logger.setLevel(int(tools.config['log_level'] or '0'))
 
-    if isinstance(handler, logging.StreamHandler) and os.name != 'nt':
+    if (not isinstance(handler, logging.FileHandler)) and os.name != 'nt':
         # change color of level names
         # uses of ANSI color codes
         # see http://pueblo.sourceforge.net/doc/manual/ansi_color_codes.html
@@ -207,6 +197,7 @@ def init_logger():
             'DEBUG_RPC': ('blue', 'white'),
             'DEBUG': ('blue', 'default'),
             'INFO': ('green', 'default'),
+            'TEST': ('white', 'blue'),
             'WARNING': ('yellow', 'default'),
             'ERROR': ('red', 'default'),
             'CRITICAL': ('white', 'red'),
@@ -218,286 +209,177 @@ def init_logger():
 
 
 class Logger(object):
-    def uncoloredlevelname(self, level):
-        # The level'names are globals to all loggers, so we must strip-off the
-        # color formatting for some specific logger (i.e: syslog)
-        levelname = logging.getLevelName(getattr(logging, level.upper(), 0))
-        if levelname.startswith("\x1b["):
-            return levelname[10:-4]
-        return levelname
+    def __init__(self):
+        warnings.warn("The netsvc.Logger API shouldn't be used anymore, please "
+                      "use the standard `logging.getLogger` API instead",
+                      PendingDeprecationWarning, stacklevel=2)
+        super(Logger, self).__init__()
+
     def notifyChannel(self, name, level, msg):
-        log = logging.getLogger(name)
+        warnings.warn("notifyChannel API shouldn't be used anymore, please use "
+                      "the standard `logging` module instead",
+                      PendingDeprecationWarning, stacklevel=2)
+        from service.web_services import common
 
-        if level == LOG_DEBUG_RPC and not hasattr(log, level):
-            fct = lambda msg, *args, **kwargs: log.log(logging.DEBUG_RPC, msg, *args, **kwargs)
-            setattr(log, LOG_DEBUG_RPC, fct)
+        log = logging.getLogger(tools.ustr(name))
+
+        if level in [LOG_DEBUG_RPC, LOG_TEST] and not hasattr(log, level):
+            fct = lambda msg, *args, **kwargs: log.log(getattr(logging, level.upper()), msg, *args, **kwargs)
+            setattr(log, level, fct)
 
-        extra = {
-            'uncoloredlevelname': self.uncoloredlevelname(level), 
-            'application' : "%s %s" % (release.description, release.version),
-        }
 
         level_method = getattr(log, level)
 
-        result = str(msg).strip().split('\n')
-        if len(result)>1:
-            for idx, s in enumerate(result):
-                level_method('[%02d]: %s' % (idx+1, s,), extra=extra)
-        elif result:
-            level_method(result[0], extra=extra)
+        if isinstance(msg, Exception):
+            msg = tools.exception_to_unicode(msg)
 
+        try:
+            msg = tools.ustr(msg).strip()
+            if level in (LOG_ERROR, LOG_CRITICAL) and tools.config.get_misc('debug','env_info',False):
+                msg = common().exp_get_server_environment() + "\n" + msg
+
+            result = msg.split('\n')
+        except UnicodeDecodeError:
+            result = msg.strip().split('\n')
+        try:
+            if len(result)>1:
+                for idx, s in enumerate(result):
+                    level_method('[%02d]: %s' % (idx+1, s,))
+            elif result:
+                level_method(result[0])
+        except IOError,e:
+            # TODO: perhaps reset the logger streams?
+            #if logrotate closes our files, we end up here..
+            pass
+        except:
+            # better ignore the exception and carry on..
+            pass
+
+    def set_loglevel(self, level):
+        log = logging.getLogger()
+        log.setLevel(logging.INFO) # make sure next msg is printed
+        log.info("Log level changed to %s" % logging.getLevelName(level))
+        log.setLevel(level)
+
+    def shutdown(self):
+        logging.shutdown()
+
+import tools
 init_logger()
 
 class Agent(object):
-    _timers = []
+    _timers = {}
     _logger = Logger()
 
-    def setAlarm(self, fn, dt, args=None, kwargs=None):
-        if not args:
-            args = []
-        if not kwargs:
-            kwargs = {}
+    __logger = logging.getLogger('timer')
+
+    def setAlarm(self, fn, dt, db_name, *args, **kwargs):
         wait = dt - time.time()
         if wait > 0:
-            self._logger.notifyChannel('timers', LOG_DEBUG, "Job scheduled in %s seconds for %s.%s" % (wait, fn.im_class.__name__, fn.func_name))
+            self.__logger.debug("Job scheduled in %.3g seconds for %s.%s" % (wait, fn.im_class.__name__, fn.func_name))
             timer = threading.Timer(wait, fn, args, kwargs)
             timer.start()
-            self._timers.append(timer)
-        for timer in self._timers[:]:
-            if not timer.isAlive():
-                self._timers.remove(timer)
-
+            self._timers.setdefault(db_name, []).append(timer)
+
+        for db in self._timers:
+            for timer in self._timers[db]:
+                if not timer.isAlive():
+                    self._timers[db].remove(timer)
+
+    @classmethod
+    def cancel(cls, db_name):
+        """Cancel all timers for a given database. If None passed, all timers are cancelled"""
+        for db in cls._timers:
+            if db_name is None or db == db_name:
+                for timer in cls._timers[db]:
+                    timer.cancel()
+
+    @classmethod
     def quit(cls):
-        for timer in cls._timers:
-            timer.cancel()
-    quit = classmethod(quit)
+        cls.cancel(None)
 
-class xmlrpc(object):
-    class RpcGateway(object):
-        def __init__(self, name):
-            self.name = name
+import traceback
 
-class GenericXMLRPCRequestHandler:
-    def log(self, title, msg):
-        from pprint import pformat
-        Logger().notifyChannel('XMLRPC-%s' % title, LOG_DEBUG_RPC, pformat(msg))
+class Server:
+    """ Generic interface for all servers with an event loop etc.
+        Override this to impement http, net-rpc etc. servers.
 
-    def _dispatch(self, method, params):
-        import traceback
-        try:
-            self.log('method', method)
-            self.log('params', params)
-            n = self.path.split("/")[-1]
-            s = LocalService(n)
-            m = getattr(s, method)
-            s._service._response = None
-            r = m(*params)
-            self.log('result', r)
-            res = s._service._response
-            if res is not None:
-                r = res
-            self.log('res',r)
-            return r
-        except Exception, e:
-            self.log('exception', e)
-            tb_s = reduce(lambda x, y: x+y, traceback.format_exception(sys.exc_type, sys.exc_value, sys.exc_traceback))
-            s = str(e)
-            import tools
-            if tools.config['debug_mode']:
-                import pdb
-                tb = sys.exc_info()[2]
-                pdb.post_mortem(tb)
-            raise xmlrpclib.Fault(s, tb_s)
-
-class SSLSocket(object):
-    def __init__(self, socket):
-        if not hasattr(socket, 'sock_shutdown'):
-            from OpenSSL import SSL
-            ctx = SSL.Context(SSL.SSLv23_METHOD)
-            ctx.use_privatekey_file('server.pkey')
-            ctx.use_certificate_file('server.cert')
-            self.socket = SSL.Connection(ctx, socket)
-        else:
-            self.socket = socket
-
-    def shutdown(self, how):
-        return self.socket.sock_shutdown(how)
-
-    def __getattr__(self, name):
-        return getattr(self.socket, name)
-
-class SimpleXMLRPCRequestHandler(GenericXMLRPCRequestHandler, SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
-    rpc_paths = map(lambda s: '/xmlrpc/%s' % s, _service)
-
-class SecureXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
-    def setup(self):
-        self.connection = SSLSocket(self.request)
-        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
-        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
-
-class SimpleThreadedXMLRPCServer(SocketServer.ThreadingMixIn, SimpleXMLRPCServer.SimpleXMLRPCServer):
-    def server_bind(self):
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        SimpleXMLRPCServer.SimpleXMLRPCServer.server_bind(self)
-
-class SecureThreadedXMLRPCServer(SimpleThreadedXMLRPCServer):
-    def __init__(self, server_address, HandlerClass, logRequests=1):
-        SimpleThreadedXMLRPCServer.__init__(self, server_address, HandlerClass, logRequests)
-        self.socket = SSLSocket(socket.socket(self.address_family, self.socket_type))
-        self.server_bind()
-        self.server_activate()
-
-class HttpDaemon(threading.Thread):
-    def __init__(self, interface, port, secure=False):
-        threading.Thread.__init__(self)
-        self.__port = port
-        self.__interface = interface
-        self.secure = bool(secure)
-        handler_class = (SimpleXMLRPCRequestHandler, SecureXMLRPCRequestHandler)[self.secure]
-        server_class = (SimpleThreadedXMLRPCServer, SecureThreadedXMLRPCServer)[self.secure]
-
-        if self.secure:
-            from OpenSSL.SSL import Error as SSLError
-        else:
-            class SSLError(Exception): pass
-        
-        try: 
-            self.server = server_class((interface, port), handler_class, 0)
-        except SSLError, e:
-            Logger().notifyChannel('xml-rpc-ssl', LOG_CRITICAL, "Can't load the certificate and/or the private key files")
-            sys.exit(1)
-        except Exception, e:
-            Logger().notifyChannel('xml-rpc', LOG_CRITICAL, "Error occur when strarting the server daemon: %s" % (e,))
-            sys.exit(1)
+        Servers here must have threaded behaviour. start() must not block,
+        there is no run().
+    """
+    __is_started = False
+    __servers = []
 
 
-    def attach(self, path, gw):
-        pass
+    __logger = logging.getLogger('server')
 
-    def stop(self):
-        self.running = False
-        if os.name != 'nt':
-            self.server.socket.shutdown( hasattr(socket, 'SHUT_RDWR') and socket.SHUT_RDWR or 2 )
-        self.server.socket.close()
-
-    def run(self):
-        self.server.register_introspection_functions()
-
-        self.running = True
-        while self.running:
-            self.server.handle_request()
-        return True
-
-        # If the server need to be run recursively
-        #
-        #signal.signal(signal.SIGALRM, self.my_handler)
-        #signal.alarm(6)
-        #while True:
-        #   self.server.handle_request()
-        #signal.alarm(0)          # Disable the alarm
-
-import tiny_socket
-class TinySocketClientThread(threading.Thread):
-    def __init__(self, sock, threads):
-        threading.Thread.__init__(self)
-        self.sock = sock
-        self.threads = threads
-        self._logger = Logger()
-
-    def log(self, msg):
-        self._logger.notifyChannel('NETRPC', LOG_DEBUG_RPC, msg)
-
-    def run(self):
-        import traceback
-        import time
-        import select
-        self.running = True
-        try:
-            ts = tiny_socket.mysocket(self.sock)
-        except:
-            self.sock.close()
-            self.threads.remove(self)
-            return False
-        while self.running:
-            try:
-                msg = ts.myreceive()
-            except:
-                self.sock.close()
-                self.threads.remove(self)
-                return False
-            try:
-                self.log(msg)
-                service = LocalService(msg[0])
-                method = getattr(service, msg[1])
-                service._service._response = None
-                result_from_method = method(*msg[2:])
-                res = service._service._response
-                if res != None:
-                    result_from_method = res
-                self.log(result_from_method)
-                ts.mysend(result_from_method)
-            except Exception, e:
-                print repr(e)
-                tb_s = reduce(lambda x, y: x+y, traceback.format_exception(sys.exc_type, sys.exc_value, sys.exc_traceback))
-                import tools
-                if tools.config['debug_mode']:
-                    import pdb
-                    tb = sys.exc_info()[2]
-                    pdb.post_mortem(tb)
-                e = Exception(str(e))
-                self.log(str(e))
-                ts.mysend(e, exception=True, traceback=tb_s)
-            except:
-                pass
-            self.sock.close()
-            self.threads.remove(self)
-            return True
+    def __init__(self):
+        if Server.__is_started:
+            raise Exception('All instances of servers must be inited before the startAll()')
+        Server.__servers.append(self)
 
-    def stop(self):
-        self.running = False
-
-
-class TinySocketServerThread(threading.Thread):
-    def __init__(self, interface, port, secure=False):
-        threading.Thread.__init__(self)
-        self.__port = port
-        self.__interface = interface
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.bind((self.__interface, self.__port))
-        self.socket.listen(5)
-        self.threads = []
-
-    def run(self):
-        import select
-        try:
-            self.running = True
-            while self.running:
-                (clientsocket, address) = self.socket.accept()
-                ct = TinySocketClientThread(clientsocket, self.threads)
-                self.threads.append(ct)
-                ct.start()
-            self.socket.close()
-        except Exception, e:
-            self.socket.close()
-            return False
+    def start(self):
+        self.__logger.debug("called stub Server.start")
 
     def stop(self):
-        self.running = False
-        for t in self.threads:
-            t.stop()
-        try:
-            if hasattr(socket, 'SHUT_RDWR'):
-                self.socket.shutdown(socket.SHUT_RDWR)
-            else:
-                self.socket.shutdown(2)
-            self.socket.close()
-        except:
-            return False
-
-
+        self.__logger.debug("called stub Server.stop")
+
+    def stats(self):
+        """ This function should return statistics about the server """
+        return "%s: No statistics" % str(self.__class__)
+
+    @classmethod
+    def startAll(cls):
+        if cls.__is_started:
+            return
+        cls.__logger.info("Starting %d services" % len(cls.__servers))
+        for srv in cls.__servers:
+            srv.start()
+        cls.__is_started = True
+
+    @classmethod
+    def quitAll(cls):
+        if not cls.__is_started:
+            return
+        cls.__logger.info("Stopping %d services" % len(cls.__servers))
+        for srv in cls.__servers:
+            srv.stop()
+        cls.__is_started = False
+
+    @classmethod
+    def allStats(cls):
+        res = ["Servers %s" % ('stopped', 'started')[cls.__is_started]]
+        res.extend(srv.stats() for srv in cls.__servers)
+        return '\n'.join(res)
+
+class OpenERPDispatcherException(Exception):
+    def __init__(self, exception, traceback):
+        self.exception = exception
+        self.traceback = traceback
+
+class OpenERPDispatcher:
+    def log(self, title, msg):
+        Logger().notifyChannel('%s' % title, LOG_DEBUG_RPC, pformat(msg))
 
+    def dispatch(self, service_name, method, params):
+        try:
+            self.log('service', service_name)
+            self.log('method', method)
+            self.log('params', params)
+            auth = getattr(self, 'auth_provider', None)
+            result = ExportService.getService(service_name).dispatch(method, auth, params)
+            self.log('result', result)
+            # We shouldn't marshall None,
+            if result == None:
+                result = False
+            return result
+        except Exception, e:
+            self.log('exception', tools.exception_to_unicode(e))
+            tb = getattr(e, 'traceback', sys.exc_info())
+            tb_s = "".join(traceback.format_exception(*tb))
+            if tools.config['debug_mode']:
+                import pdb
+                pdb.post_mortem(tb[2])
+            raise OpenERPDispatcherException(e, tb_s)
 
 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:
-