[MERGE] OSX handling of socket shutdown.
[odoo/odoo.git] / openerp / netsvc.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 ##############################################################################
4 #
5 #    OpenERP, Open Source Management Solution
6 #    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>). All Rights Reserved
7 #    The refactoring about the OpenSSL support come from Tryton
8 #    Copyright (C) 2007-2009 Cédric Krier.
9 #    Copyright (C) 2007-2009 Bertrand Chenal.
10 #    Copyright (C) 2008 B2CK SPRL.
11 #
12 #    This program is free software: you can redistribute it and/or modify
13 #    it under the terms of the GNU General Public License as published by
14 #    the Free Software Foundation, either version 3 of the License, or
15 #    (at your option) any later version.
16 #
17 #    This program is distributed in the hope that it will be useful,
18 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
19 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20 #    GNU General Public License for more details.
21 #
22 #    You should have received a copy of the GNU General Public License
23 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
24 #
25 ##############################################################################
26
27 import cgitb
28 import errno
29 import heapq
30 import logging
31 import logging.handlers
32 import os
33 import platform
34 import release
35 import socket
36 import sys
37 import threading
38 import time
39 import warnings
40 from pprint import pformat
41
42 # TODO modules that import netsvc only for things from loglevels must be changed to use loglevels.
43 from loglevels import *
44 import tools
45
46 def close_socket(sock):
47     """ Closes a socket instance cleanly
48
49     :param sock: the network socket to close
50     :type sock: socket.socket
51     """
52     try:
53         sock.shutdown(socket.SHUT_RDWR)
54     except socket.error, e:
55         # On OSX, socket shutdowns both sides if any side closes it
56         # causing an error 57 'Socket is not connected' on shutdown
57         # of the other side (or something), see
58         # http://bugs.python.org/issue4397
59         # note: stdlib fixed test, not behavior
60         if e.errno != errno.ENOTCONN or platform.system() != 'Darwin':
61             raise
62     sock.close()
63
64
65 class Service(object):
66     """ Base class for *Local* services
67
68         Functionality here is trusted, no authentication.
69     """
70     _services = {}
71     def __init__(self, name, audience=''):
72         Service._services[name] = self
73         self.__name = name
74         self._methods = {}
75
76     def joinGroup(self, name):
77         raise Exception("No group for local services")
78         #GROUPS.setdefault(name, {})[self.__name] = self
79
80     @classmethod
81     def exists(cls, name):
82         return name in cls._services
83
84     @classmethod
85     def remove(cls, name):
86         if cls.exists(name):
87             cls._services.pop(name)
88
89     def exportMethod(self, method):
90         if callable(method):
91             self._methods[method.__name__] = method
92
93     def abortResponse(self, error, description, origin, details):
94         if not tools.config['debug_mode']:
95             raise Exception("%s -- %s\n\n%s"%(origin, description, details))
96         else:
97             raise
98
99 class LocalService(object):
100     """ Proxy for local services. 
101     
102         Any instance of this class will behave like the single instance
103         of Service(name)
104     """
105     __logger = logging.getLogger('service')
106     def __init__(self, name):
107         self.__name = name
108         try:
109             self._service = Service._services[name]
110             for method_name, method_definition in self._service._methods.items():
111                 setattr(self, method_name, method_definition)
112         except KeyError, keyError:
113             self.__logger.error('This service does not exist: %s' % (str(keyError),) )
114             raise
115
116     def __call__(self, method, *params):
117         return getattr(self, method)(*params)
118
119 class ExportService(object):
120     """ Proxy for exported services. 
121
122     All methods here should take an AuthProxy as their first parameter. It
123     will be appended by the calling framework.
124
125     Note that this class has no direct proxy, capable of calling 
126     eservice.method(). Rather, the proxy should call 
127     dispatch(method,auth,params)
128     """
129     
130     _services = {}
131     _groups = {}
132     _logger = logging.getLogger('web-services')
133     
134     def __init__(self, name, audience=''):
135         ExportService._services[name] = self
136         self.__name = name
137         self._logger.debug("Registered an exported service: %s" % name)
138
139     def joinGroup(self, name):
140         ExportService._groups.setdefault(name, {})[self.__name] = self
141
142     @classmethod
143     def getService(cls,name):
144         return cls._services[name]
145
146     def dispatch(self, method, auth, params):
147         raise Exception("stub dispatch at %s" % self.__name)
148         
149     def new_dispatch(self,method,auth,params):
150         raise Exception("stub dispatch at %s" % self.__name)
151
152     def abortResponse(self, error, description, origin, details):
153         if not tools.config['debug_mode']:
154             raise Exception("%s -- %s\n\n%s"%(origin, description, details))
155         else:
156             raise
157
158 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, _NOTHING, DEFAULT = range(10)
159 #The background is set with 40 plus the number of the color, and the foreground with 30
160 #These are the sequences need to get colored ouput
161 RESET_SEQ = "\033[0m"
162 COLOR_SEQ = "\033[1;%dm"
163 BOLD_SEQ = "\033[1m"
164 COLOR_PATTERN = "%s%s%%s%s" % (COLOR_SEQ, COLOR_SEQ, RESET_SEQ)
165 LEVEL_COLOR_MAPPING = {
166     logging.DEBUG_SQL: (WHITE, MAGENTA),
167     logging.DEBUG_RPC: (BLUE, WHITE),
168     logging.DEBUG_RPC_ANSWER: (BLUE, WHITE),
169     logging.DEBUG: (BLUE, DEFAULT),
170     logging.INFO: (GREEN, DEFAULT),
171     logging.TEST: (WHITE, BLUE),
172     logging.WARNING: (YELLOW, DEFAULT),
173     logging.ERROR: (RED, DEFAULT),
174     logging.CRITICAL: (WHITE, RED),
175 }
176
177 class DBFormatter(logging.Formatter):
178     def format(self, record):
179         record.dbname = getattr(threading.currentThread(), 'dbname', '?')
180         return logging.Formatter.format(self, record)
181
182 class ColoredFormatter(DBFormatter):
183     def format(self, record):
184         fg_color, bg_color = LEVEL_COLOR_MAPPING[record.levelno]
185         record.levelname = COLOR_PATTERN % (30 + fg_color, 40 + bg_color, record.levelname)
186         return DBFormatter.format(self, record)
187
188 def init_logger():
189     import os
190     from tools.translate import resetlocale
191     resetlocale()
192
193     # create a format for log messages and dates
194     format = '[%(asctime)s][%(dbname)s] %(levelname)s:%(name)s:%(message)s'
195
196     if tools.config['syslog']:
197         # SysLog Handler
198         if os.name == 'nt':
199             handler = logging.handlers.NTEventLogHandler("%s %s" % (release.description, release.version))
200         else:
201             handler = logging.handlers.SysLogHandler('/dev/log')
202         format = '%s %s' % (release.description, release.version) \
203                 + ':%(dbname)s:%(levelname)s:%(name)s:%(message)s'
204
205     elif tools.config['logfile']:
206         # LogFile Handler
207         logf = tools.config['logfile']
208         try:
209             dirname = os.path.dirname(logf)
210             if dirname and not os.path.isdir(dirname):
211                 os.makedirs(dirname)
212             if tools.config['logrotate'] is not False:
213                 handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
214             elif os.name == 'posix':
215                 handler = logging.handlers.WatchedFileHandler(logf)
216             else:
217                 handler = logging.handlers.FileHandler(logf)
218         except Exception:
219             sys.stderr.write("ERROR: couldn't create the logfile directory. Logging to the standard output.\n")
220             handler = logging.StreamHandler(sys.stdout)
221     else:
222         # Normal Handler on standard output
223         handler = logging.StreamHandler(sys.stdout)
224
225     if isinstance(handler, logging.StreamHandler) and os.isatty(handler.stream.fileno()):
226         formatter = ColoredFormatter(format)
227     else:
228         formatter = DBFormatter(format)
229     handler.setFormatter(formatter)
230
231     # add the handler to the root logger
232     logger = logging.getLogger()
233     logger.handlers = []
234     logger.addHandler(handler)
235     logger.setLevel(int(tools.config['log_level'] or '0'))
236
237 # A alternative logging scheme for automated runs of the
238 # server intended to test it.
239 def init_alternative_logger():
240     class H(logging.Handler):
241       def emit(self, record):
242         if record.levelno > 20:
243           print record.levelno, record.pathname, record.msg
244     handler = H()
245     logger = logging.getLogger()
246     logger.handlers = []
247     logger.addHandler(handler)
248     logger.setLevel(logging.ERROR)
249
250 class Agent(object):
251     """Singleton that keeps track of cancellable tasks to run at a given
252        timestamp.
253        The tasks are caracterised by:
254             * a timestamp
255             * the database on which the task run
256             * the function to call
257             * the arguments and keyword arguments to pass to the function
258
259         Implementation details:
260           Tasks are stored as list, allowing the cancellation by setting
261           the timestamp to 0.
262           A heapq is used to store tasks, so we don't need to sort
263           tasks ourself.
264     """
265     __tasks = []
266     __tasks_by_db = {}
267     _logger = logging.getLogger('netsvc.agent')
268
269     @classmethod
270     def setAlarm(cls, function, timestamp, db_name, *args, **kwargs):
271         task = [timestamp, db_name, function, args, kwargs]
272         heapq.heappush(cls.__tasks, task)
273         cls.__tasks_by_db.setdefault(db_name, []).append(task)
274
275     @classmethod
276     def cancel(cls, db_name):
277         """Cancel all tasks for a given database. If None is passed, all tasks are cancelled"""
278         cls._logger.debug("Cancel timers for %s db", db_name or 'all')
279         if db_name is None:
280             cls.__tasks, cls.__tasks_by_db = [], {}
281         else:
282             if db_name in cls.__tasks_by_db:
283                 for task in cls.__tasks_by_db[db_name]:
284                     task[0] = 0
285
286     @classmethod
287     def quit(cls):
288         cls.cancel(None)
289
290     @classmethod
291     def runner(cls):
292         """Neverending function (intended to be ran in a dedicated thread) that
293            checks every 60 seconds tasks to run. TODO: make configurable
294         """
295         current_thread = threading.currentThread()
296         while True:
297             while cls.__tasks and cls.__tasks[0][0] < time.time():
298                 task = heapq.heappop(cls.__tasks)
299                 timestamp, dbname, function, args, kwargs = task
300                 cls.__tasks_by_db[dbname].remove(task)
301                 if not timestamp:
302                     # null timestamp -> cancelled task
303                     continue
304                 current_thread.dbname = dbname   # hack hack
305                 cls._logger.debug("Run %s.%s(*%s, **%s)", function.im_class.__name__, function.func_name, args, kwargs)
306                 delattr(current_thread, 'dbname')
307                 task_thread = threading.Thread(target=function, name='netsvc.Agent.task', args=args, kwargs=kwargs)
308                 # force non-daemon task threads (the runner thread must be daemon, and this property is inherited by default)
309                 task_thread.setDaemon(False)
310                 task_thread.start()
311                 time.sleep(1)
312             time.sleep(60)
313
314 def start_agent():
315     agent_runner = threading.Thread(target=Agent.runner, name="netsvc.Agent.runner")
316     # the agent runner is a typical daemon thread, that will never quit and must be
317     # terminated when the main process exits - with no consequence (the processing
318     # threads it spawns are not marked daemon)
319     agent_runner.setDaemon(True)
320     agent_runner.start()
321
322 import traceback
323
324 class Server:
325     """ Generic interface for all servers with an event loop etc.
326         Override this to impement http, net-rpc etc. servers.
327
328         Servers here must have threaded behaviour. start() must not block,
329         there is no run().
330     """
331     __is_started = False
332     __servers = []
333     __starter_threads = []
334
335     # we don't want blocking server calls (think select()) to
336     # wait forever and possibly prevent exiting the process,
337     # but instead we want a form of polling/busy_wait pattern, where
338     # _server_timeout should be used as the default timeout for
339     # all I/O blocking operations
340     _busywait_timeout = 0.5
341
342
343     __logger = logging.getLogger('server')
344
345     def __init__(self):
346         Server.__servers.append(self)
347         if Server.__is_started:
348             # raise Exception('All instances of servers must be inited before the startAll()')
349             # Since the startAll() won't be called again, allow this server to
350             # init and then start it after 1sec (hopefully). Register that
351             # timer thread in a list, so that we can abort the start if quitAll
352             # is called in the meantime
353             t = threading.Timer(1.0, self._late_start)
354             t.name = 'Late start timer for %s' % str(self.__class__)
355             Server.__starter_threads.append(t)
356             t.start()
357
358     def start(self):
359         self.__logger.debug("called stub Server.start")
360         
361     def _late_start(self):
362         self.start()
363         for thr in Server.__starter_threads:
364             if thr.finished.is_set():
365                 Server.__starter_threads.remove(thr)
366
367     def stop(self):
368         self.__logger.debug("called stub Server.stop")
369
370     def stats(self):
371         """ This function should return statistics about the server """
372         return "%s: No statistics" % str(self.__class__)
373
374     @classmethod
375     def startAll(cls):
376         if cls.__is_started:
377             return
378         cls.__logger.info("Starting %d services" % len(cls.__servers))
379         for srv in cls.__servers:
380             srv.start()
381         cls.__is_started = True
382
383     @classmethod
384     def quitAll(cls):
385         if not cls.__is_started:
386             return
387         cls.__logger.info("Stopping %d services" % len(cls.__servers))
388         for thr in cls.__starter_threads:
389             if not thr.finished.is_set():
390                 thr.cancel()
391             cls.__starter_threads.remove(thr)
392
393         for srv in cls.__servers:
394             srv.stop()
395         cls.__is_started = False
396
397     @classmethod
398     def allStats(cls):
399         res = ["Servers %s" % ('stopped', 'started')[cls.__is_started]]
400         res.extend(srv.stats() for srv in cls.__servers)
401         return '\n'.join(res)
402
403     def _close_socket(self):
404         close_socket(self.socket)
405
406 class OpenERPDispatcherException(Exception):
407     def __init__(self, exception, traceback):
408         self.exception = exception
409         self.traceback = traceback
410
411 def replace_request_password(args):
412     # password is always 3rd argument in a request, we replace it in RPC logs
413     # so it's easier to forward logs for diagnostics/debugging purposes...
414     args = list(args)
415     if len(args) > 2:
416         args[2] = '*'
417     return args
418
419 def log(title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
420     logger = logging.getLogger(title)
421     if logger.isEnabledFor(channel):
422         indent=''
423         indent_after=' '*len(fn)
424         for line in (fn+pformat(msg, depth=depth)).split('\n'):
425             logger.log(channel, indent+line)
426             indent=indent_after
427
428 class OpenERPDispatcher:
429     def log(self, title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
430         log(title, msg, channel=channel, depth=depth, fn=fn)
431     def dispatch(self, service_name, method, params):
432         try:
433             logger = logging.getLogger('result')
434             self.log('service', tuple(replace_request_password(params)), depth=(None if logger.isEnabledFor(logging.DEBUG_RPC_ANSWER) else 1), fn='%s.%s'%(service_name,method))
435             auth = getattr(self, 'auth_provider', None)
436             result = ExportService.getService(service_name).dispatch(method, auth, params)
437             self.log('result', result, channel=logging.DEBUG_RPC_ANSWER)
438             return result
439         except Exception, e:
440             self.log('exception', tools.exception_to_unicode(e))
441             tb = getattr(e, 'traceback', sys.exc_info())
442             tb_s = cgitb.text(tb)
443             if tools.config['debug_mode']:
444                 import pdb
445                 pdb.post_mortem(tb[2])
446             raise OpenERPDispatcherException(e, tb_s)
447
448 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: