81a97f124d15746dc3f558f1eac400d8f4be9583
[odoo/odoo.git] / openerp / netsvc.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 ##############################################################################
4 #
5 #    OpenERP, Open Source Management Solution
6 #    Copyright (C) 2004-2011 OpenERP SA (<http://www.openerp.com>)
7 #
8 #    This program is free software: you can redistribute it and/or modify
9 #    it under the terms of the GNU Affero General Public License as
10 #    published by the Free Software Foundation, either version 3 of the
11 #    License, or (at your option) any later version.
12 #
13 #    This program is distributed in the hope that it will be useful,
14 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
15 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 #    GNU Affero General Public License for more details.
17 #
18 #    You should have received a copy of the GNU Affero General Public License
19 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 #
21 ##############################################################################
22
23 import errno
24 import heapq
25 import logging
26 import logging.handlers
27 import os
28 import platform
29 import release
30 import socket
31 import sys
32 import threading
33 import time
34 import types
35 from pprint import pformat
36
37 # TODO modules that import netsvc only for things from loglevels must be changed to use loglevels.
38 from loglevels import *
39 import tools
40
41 def close_socket(sock):
42     """ Closes a socket instance cleanly
43
44     :param sock: the network socket to close
45     :type sock: socket.socket
46     """
47     try:
48         sock.shutdown(socket.SHUT_RDWR)
49     except socket.error, e:
50         # On OSX, socket shutdowns both sides if any side closes it
51         # causing an error 57 'Socket is not connected' on shutdown
52         # of the other side (or something), see
53         # http://bugs.python.org/issue4397
54         # note: stdlib fixed test, not behavior
55         if e.errno != errno.ENOTCONN or platform.system() != 'Darwin':
56             raise
57     sock.close()
58
59
60 #.apidoc title: Common Services: netsvc
61 #.apidoc module-mods: member-order: bysource
62
63 class Service(object):
64     """ Base class for *Local* services
65
66         Functionality here is trusted, no authentication.
67     """
68     _services = {}
69     def __init__(self, name, audience=''):
70         Service._services[name] = self
71         self.__name = name
72         self._methods = {}
73
74     def joinGroup(self, name):
75         raise Exception("No group for local services")
76         #GROUPS.setdefault(name, {})[self.__name] = self
77
78     @classmethod
79     def exists(cls, name):
80         return name in cls._services
81
82     @classmethod
83     def remove(cls, name):
84         if cls.exists(name):
85             cls._services.pop(name)
86
87     def exportMethod(self, method):
88         if callable(method):
89             self._methods[method.__name__] = method
90
91     def abortResponse(self, error, description, origin, details):
92         if not tools.config['debug_mode']:
93             raise Exception("%s -- %s\n\n%s"%(origin, description, details))
94         else:
95             raise
96
97 def LocalService(name):
98   return Service._services[name]
99
100 class ExportService(object):
101     """ Proxy for exported services. 
102
103     All methods here should take an AuthProxy as their first parameter. It
104     will be appended by the calling framework.
105
106     Note that this class has no direct proxy, capable of calling 
107     eservice.method(). Rather, the proxy should call 
108     dispatch(method,auth,params)
109     """
110     
111     _services = {}
112     _groups = {}
113     _logger = logging.getLogger('web-services')
114     
115     def __init__(self, name, audience=''):
116         ExportService._services[name] = self
117         self.__name = name
118         self._logger.debug("Registered an exported service: %s" % name)
119
120     def joinGroup(self, name):
121         ExportService._groups.setdefault(name, {})[self.__name] = self
122
123     @classmethod
124     def getService(cls,name):
125         return cls._services[name]
126
127     def dispatch(self, method, auth, params):
128         raise Exception("stub dispatch at %s" % self.__name)
129         
130     def abortResponse(self, error, description, origin, details):
131         if not tools.config['debug_mode']:
132             raise Exception("%s -- %s\n\n%s"%(origin, description, details))
133         else:
134             raise
135
136 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, _NOTHING, DEFAULT = range(10)
137 #The background is set with 40 plus the number of the color, and the foreground with 30
138 #These are the sequences need to get colored ouput
139 RESET_SEQ = "\033[0m"
140 COLOR_SEQ = "\033[1;%dm"
141 BOLD_SEQ = "\033[1m"
142 COLOR_PATTERN = "%s%s%%s%s" % (COLOR_SEQ, COLOR_SEQ, RESET_SEQ)
143 LEVEL_COLOR_MAPPING = {
144     logging.DEBUG_SQL: (WHITE, MAGENTA),
145     logging.DEBUG_RPC: (BLUE, WHITE),
146     logging.DEBUG_RPC_ANSWER: (BLUE, WHITE),
147     logging.DEBUG: (BLUE, DEFAULT),
148     logging.INFO: (GREEN, DEFAULT),
149     logging.TEST: (WHITE, BLUE),
150     logging.WARNING: (YELLOW, DEFAULT),
151     logging.ERROR: (RED, DEFAULT),
152     logging.CRITICAL: (WHITE, RED),
153 }
154
155 class DBFormatter(logging.Formatter):
156     def format(self, record):
157         record.dbname = getattr(threading.currentThread(), 'dbname', '?')
158         return logging.Formatter.format(self, record)
159
160 class ColoredFormatter(DBFormatter):
161     def format(self, record):
162         fg_color, bg_color = LEVEL_COLOR_MAPPING[record.levelno]
163         record.levelname = COLOR_PATTERN % (30 + fg_color, 40 + bg_color, record.levelname)
164         return DBFormatter.format(self, record)
165
166 def init_logger():
167     from tools.translate import resetlocale
168     resetlocale()
169
170     # create a format for log messages and dates
171     format = '[%(asctime)s][%(dbname)s] %(levelname)s:%(name)s:%(message)s'
172
173     if tools.config['syslog']:
174         # SysLog Handler
175         if os.name == 'nt':
176             handler = logging.handlers.NTEventLogHandler("%s %s" % (release.description, release.version))
177         else:
178             handler = logging.handlers.SysLogHandler('/dev/log')
179         format = '%s %s' % (release.description, release.version) \
180                 + ':%(dbname)s:%(levelname)s:%(name)s:%(message)s'
181
182     elif tools.config['logfile']:
183         # LogFile Handler
184         logf = tools.config['logfile']
185         try:
186             dirname = os.path.dirname(logf)
187             if dirname and not os.path.isdir(dirname):
188                 os.makedirs(dirname)
189             if tools.config['logrotate'] is not False:
190                 handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
191             elif os.name == 'posix':
192                 handler = logging.handlers.WatchedFileHandler(logf)
193             else:
194                 handler = logging.handlers.FileHandler(logf)
195         except Exception:
196             sys.stderr.write("ERROR: couldn't create the logfile directory. Logging to the standard output.\n")
197             handler = logging.StreamHandler(sys.stdout)
198     else:
199         # Normal Handler on standard output
200         handler = logging.StreamHandler(sys.stdout)
201
202     if isinstance(handler, logging.StreamHandler) and os.isatty(handler.stream.fileno()):
203         formatter = ColoredFormatter(format)
204     else:
205         formatter = DBFormatter(format)
206     handler.setFormatter(formatter)
207
208     # add the handler to the root logger
209     logger = logging.getLogger()
210     logger.handlers = []
211     logger.addHandler(handler)
212     logger.setLevel(int(tools.config['log_level'] or '0'))
213
214 # A alternative logging scheme for automated runs of the
215 # server intended to test it.
216 def init_alternative_logger():
217     class H(logging.Handler):
218       def emit(self, record):
219         if record.levelno > 20:
220           print record.levelno, record.pathname, record.msg
221     handler = H()
222     logger = logging.getLogger()
223     logger.handlers = []
224     logger.addHandler(handler)
225     logger.setLevel(logging.ERROR)
226
227 class Agent(object):
228     """ Singleton that keeps track of cancellable tasks to run at a given
229         timestamp.
230        
231         The tasks are characterised by:
232        
233             * a timestamp
234             * the database on which the task run
235             * the function to call
236             * the arguments and keyword arguments to pass to the function
237
238         Implementation details:
239         
240           - Tasks are stored as list, allowing the cancellation by setting
241             the timestamp to 0.
242           - A heapq is used to store tasks, so we don't need to sort
243             tasks ourself.
244     """
245     __tasks = []
246     __tasks_by_db = {}
247     _logger = logging.getLogger('netsvc.agent')
248
249     @classmethod
250     def setAlarm(cls, function, timestamp, db_name, *args, **kwargs):
251         task = [timestamp, db_name, function, args, kwargs]
252         heapq.heappush(cls.__tasks, task)
253         cls.__tasks_by_db.setdefault(db_name, []).append(task)
254
255     @classmethod
256     def cancel(cls, db_name):
257         """Cancel all tasks for a given database. If None is passed, all tasks are cancelled"""
258         cls._logger.debug("Cancel timers for %s db", db_name or 'all')
259         if db_name is None:
260             cls.__tasks, cls.__tasks_by_db = [], {}
261         else:
262             if db_name in cls.__tasks_by_db:
263                 for task in cls.__tasks_by_db[db_name]:
264                     task[0] = 0
265
266     @classmethod
267     def quit(cls):
268         cls.cancel(None)
269
270     @classmethod
271     def runner(cls):
272         """Neverending function (intended to be ran in a dedicated thread) that
273            checks every 60 seconds tasks to run. TODO: make configurable
274         """
275         current_thread = threading.currentThread()
276         while True:
277             while cls.__tasks and cls.__tasks[0][0] < time.time():
278                 task = heapq.heappop(cls.__tasks)
279                 timestamp, dbname, function, args, kwargs = task
280                 cls.__tasks_by_db[dbname].remove(task)
281                 if not timestamp:
282                     # null timestamp -> cancelled task
283                     continue
284                 current_thread.dbname = dbname   # hack hack
285                 cls._logger.debug("Run %s.%s(*%s, **%s)", function.im_class.__name__, function.func_name, args, kwargs)
286                 delattr(current_thread, 'dbname')
287                 task_thread = threading.Thread(target=function, name='netsvc.Agent.task', args=args, kwargs=kwargs)
288                 # force non-daemon task threads (the runner thread must be daemon, and this property is inherited by default)
289                 task_thread.setDaemon(False)
290                 task_thread.start()
291                 time.sleep(1)
292             time.sleep(60)
293
294 def start_agent():
295     agent_runner = threading.Thread(target=Agent.runner, name="netsvc.Agent.runner")
296     # the agent runner is a typical daemon thread, that will never quit and must be
297     # terminated when the main process exits - with no consequence (the processing
298     # threads it spawns are not marked daemon)
299     agent_runner.setDaemon(True)
300     agent_runner.start()
301
302 import traceback
303
304 class Server:
305     """ Generic interface for all servers with an event loop etc.
306         Override this to impement http, net-rpc etc. servers.
307
308         Servers here must have threaded behaviour. start() must not block,
309         there is no run().
310     """
311     __is_started = False
312     __servers = []
313     __starter_threads = []
314
315     # we don't want blocking server calls (think select()) to
316     # wait forever and possibly prevent exiting the process,
317     # but instead we want a form of polling/busy_wait pattern, where
318     # _server_timeout should be used as the default timeout for
319     # all I/O blocking operations
320     _busywait_timeout = 0.5
321
322
323     __logger = logging.getLogger('server')
324
325     def __init__(self):
326         Server.__servers.append(self)
327         if Server.__is_started:
328             # raise Exception('All instances of servers must be inited before the startAll()')
329             # Since the startAll() won't be called again, allow this server to
330             # init and then start it after 1sec (hopefully). Register that
331             # timer thread in a list, so that we can abort the start if quitAll
332             # is called in the meantime
333             t = threading.Timer(1.0, self._late_start)
334             t.name = 'Late start timer for %s' % str(self.__class__)
335             Server.__starter_threads.append(t)
336             t.start()
337
338     def start(self):
339         self.__logger.debug("called stub Server.start")
340         
341     def _late_start(self):
342         self.start()
343         for thr in Server.__starter_threads:
344             if thr.finished.is_set():
345                 Server.__starter_threads.remove(thr)
346
347     def stop(self):
348         self.__logger.debug("called stub Server.stop")
349
350     def stats(self):
351         """ This function should return statistics about the server """
352         return "%s: No statistics" % str(self.__class__)
353
354     @classmethod
355     def startAll(cls):
356         if cls.__is_started:
357             return
358         cls.__logger.info("Starting %d services" % len(cls.__servers))
359         for srv in cls.__servers:
360             srv.start()
361         cls.__is_started = True
362
363     @classmethod
364     def quitAll(cls):
365         if not cls.__is_started:
366             return
367         cls.__logger.info("Stopping %d services" % len(cls.__servers))
368         for thr in cls.__starter_threads:
369             if not thr.finished.is_set():
370                 thr.cancel()
371             cls.__starter_threads.remove(thr)
372
373         for srv in cls.__servers:
374             srv.stop()
375         cls.__is_started = False
376
377     @classmethod
378     def allStats(cls):
379         res = ["Servers %s" % ('stopped', 'started')[cls.__is_started]]
380         res.extend(srv.stats() for srv in cls.__servers)
381         return '\n'.join(res)
382
383     def _close_socket(self):
384         close_socket(self.socket)
385
386 class OpenERPDispatcherException(Exception):
387     def __init__(self, exception, traceback):
388         self.exception = exception
389         self.traceback = traceback
390
391 def replace_request_password(args):
392     # password is always 3rd argument in a request, we replace it in RPC logs
393     # so it's easier to forward logs for diagnostics/debugging purposes...
394     args = list(args)
395     if len(args) > 2:
396         args[2] = '*'
397     return args
398
399 def log(title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
400     logger = logging.getLogger(title)
401     if logger.isEnabledFor(channel):
402         indent=''
403         indent_after=' '*len(fn)
404         for line in (fn+pformat(msg, depth=depth)).split('\n'):
405             logger.log(channel, indent+line)
406             indent=indent_after
407
408 class OpenERPDispatcher:
409     def log(self, title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
410         log(title, msg, channel=channel, depth=depth, fn=fn)
411     def dispatch(self, service_name, method, params):
412         try:
413             auth = getattr(self, 'auth_provider', None)
414             logger = logging.getLogger('result')
415             start_time = end_time = 0
416             if logger.isEnabledFor(logging.DEBUG_RPC_ANSWER):
417                 self.log('service', tuple(replace_request_password(params)), depth=None, fn='%s.%s'%(service_name,method))
418             if logger.isEnabledFor(logging.DEBUG_RPC):
419                 start_time = time.time()
420             result = ExportService.getService(service_name).dispatch(method, auth, params)
421             if logger.isEnabledFor(logging.DEBUG_RPC):
422                 end_time = time.time()
423             if not logger.isEnabledFor(logging.DEBUG_RPC_ANSWER):
424                 self.log('service (%.3fs)' % (end_time - start_time), tuple(replace_request_password(params)), depth=1, fn='%s.%s'%(service_name,method))
425             self.log('execution time', '%.3fs' % (end_time - start_time), channel=logging.DEBUG_RPC_ANSWER)
426             self.log('result', result, channel=logging.DEBUG_RPC_ANSWER)
427             return result
428         except Exception, e:
429             self.log('exception', tools.exception_to_unicode(e))
430             tb = getattr(e, 'traceback', sys.exc_info())
431             tb_s = "".join(traceback.format_exception(*tb))
432             if tools.config['debug_mode'] and isinstance(tb, types.TracebackType):
433                 import pdb
434                 pdb.post_mortem(tb[2])
435             raise OpenERPDispatcherException(e, tb_s)
436
437 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: