0892d47c646fd63cf4f651866e49fe8a29c0e2e0
[odoo/odoo.git] / openerp / netsvc.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 ##############################################################################
4 #
5 #    OpenERP, Open Source Management Solution
6 #    Copyright (C) 2004-2011 OpenERP SA (<http://www.openerp.com>)
7 #
8 #    This program is free software: you can redistribute it and/or modify
9 #    it under the terms of the GNU Affero General Public License as
10 #    published by the Free Software Foundation, either version 3 of the
11 #    License, or (at your option) any later version.
12 #
13 #    This program is distributed in the hope that it will be useful,
14 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
15 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 #    GNU Affero General Public License for more details.
17 #
18 #    You should have received a copy of the GNU Affero General Public License
19 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 #
21 ##############################################################################
22
23 import errno
24 import heapq
25 import logging
26 import logging.handlers
27 import os
28 import platform
29 import release
30 import socket
31 import sys
32 import threading
33 import time
34 import types
35 from pprint import pformat
36
37 # TODO modules that import netsvc only for things from loglevels must be changed to use loglevels.
38 from loglevels import *
39 import tools
40
41 def close_socket(sock):
42     """ Closes a socket instance cleanly
43
44     :param sock: the network socket to close
45     :type sock: socket.socket
46     """
47     try:
48         sock.shutdown(socket.SHUT_RDWR)
49     except socket.error, e:
50         # On OSX, socket shutdowns both sides if any side closes it
51         # causing an error 57 'Socket is not connected' on shutdown
52         # of the other side (or something), see
53         # http://bugs.python.org/issue4397
54         # note: stdlib fixed test, not behavior
55         if e.errno != errno.ENOTCONN or platform.system() != 'Darwin':
56             raise
57     sock.close()
58
59
60 #.apidoc title: Common Services: netsvc
61 #.apidoc module-mods: member-order: bysource
62
63 def abort_response(error, description, origin, details):
64     if not tools.config['debug_mode']:
65         raise Exception("%s -- %s\n\n%s"%(origin, description, details))
66     else:
67         raise
68
69 class Service(object):
70     """ Base class for *Local* services
71
72         Functionality here is trusted, no authentication.
73     """
74     _services = {}
75     def __init__(self, name):
76         Service._services[name] = self
77         self.__name = name
78
79     @classmethod
80     def exists(cls, name):
81         return name in cls._services
82
83     @classmethod
84     def remove(cls, name):
85         if cls.exists(name):
86             cls._services.pop(name)
87
88 def LocalService(name):
89   # Special case for addons support, will be removed in a few days when addons
90   # are updated to directly use openerp.osv.osv.service.
91   if name == 'object_proxy':
92       return openerp.osv.osv.service
93
94   return Service._services[name]
95
96 class ExportService(object):
97     """ Proxy for exported services.
98
99     All methods here should take an AuthProxy as their first parameter. It
100     will be appended by the calling framework.
101
102     Note that this class has no direct proxy, capable of calling
103     eservice.method(). Rather, the proxy should call
104     dispatch(method,auth,params)
105     """
106
107     _services = {}
108     _logger = logging.getLogger('web-services')
109     
110     def __init__(self, name):
111         ExportService._services[name] = self
112         self.__name = name
113         self._logger.debug("Registered an exported service: %s" % name)
114
115     @classmethod
116     def getService(cls,name):
117         return cls._services[name]
118
119     # Dispatch a RPC call w.r.t. the method name. The dispatching
120     # w.r.t. the service (this class) is done by OpenERPDispatcher.
121     def dispatch(self, method, auth, params):
122         raise Exception("stub dispatch at %s" % self.__name)
123
124 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, _NOTHING, DEFAULT = range(10)
125 #The background is set with 40 plus the number of the color, and the foreground with 30
126 #These are the sequences need to get colored ouput
127 RESET_SEQ = "\033[0m"
128 COLOR_SEQ = "\033[1;%dm"
129 BOLD_SEQ = "\033[1m"
130 COLOR_PATTERN = "%s%s%%s%s" % (COLOR_SEQ, COLOR_SEQ, RESET_SEQ)
131 LEVEL_COLOR_MAPPING = {
132     logging.DEBUG_SQL: (WHITE, MAGENTA),
133     logging.DEBUG_RPC: (BLUE, WHITE),
134     logging.DEBUG_RPC_ANSWER: (BLUE, WHITE),
135     logging.DEBUG: (BLUE, DEFAULT),
136     logging.INFO: (GREEN, DEFAULT),
137     logging.TEST: (WHITE, BLUE),
138     logging.WARNING: (YELLOW, DEFAULT),
139     logging.ERROR: (RED, DEFAULT),
140     logging.CRITICAL: (WHITE, RED),
141 }
142
143 class DBFormatter(logging.Formatter):
144     def format(self, record):
145         record.dbname = getattr(threading.currentThread(), 'dbname', '?')
146         return logging.Formatter.format(self, record)
147
148 class ColoredFormatter(DBFormatter):
149     def format(self, record):
150         fg_color, bg_color = LEVEL_COLOR_MAPPING[record.levelno]
151         record.levelname = COLOR_PATTERN % (30 + fg_color, 40 + bg_color, record.levelname)
152         return DBFormatter.format(self, record)
153
154 def init_logger():
155     from tools.translate import resetlocale
156     resetlocale()
157
158     # create a format for log messages and dates
159     format = '[%(asctime)s][%(dbname)s] %(levelname)s:%(name)s:%(message)s'
160
161     if tools.config['syslog']:
162         # SysLog Handler
163         if os.name == 'nt':
164             handler = logging.handlers.NTEventLogHandler("%s %s" % (release.description, release.version))
165         else:
166             handler = logging.handlers.SysLogHandler('/dev/log')
167         format = '%s %s' % (release.description, release.version) \
168                 + ':%(dbname)s:%(levelname)s:%(name)s:%(message)s'
169
170     elif tools.config['logfile']:
171         # LogFile Handler
172         logf = tools.config['logfile']
173         try:
174             dirname = os.path.dirname(logf)
175             if dirname and not os.path.isdir(dirname):
176                 os.makedirs(dirname)
177             if tools.config['logrotate'] is not False:
178                 handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
179             elif os.name == 'posix':
180                 handler = logging.handlers.WatchedFileHandler(logf)
181             else:
182                 handler = logging.handlers.FileHandler(logf)
183         except Exception:
184             sys.stderr.write("ERROR: couldn't create the logfile directory. Logging to the standard output.\n")
185             handler = logging.StreamHandler(sys.stdout)
186     else:
187         # Normal Handler on standard output
188         handler = logging.StreamHandler(sys.stdout)
189
190     if isinstance(handler, logging.StreamHandler) and os.isatty(handler.stream.fileno()):
191         formatter = ColoredFormatter(format)
192     else:
193         formatter = DBFormatter(format)
194     handler.setFormatter(formatter)
195
196     # add the handler to the root logger
197     logger = logging.getLogger()
198     logger.handlers = []
199     logger.addHandler(handler)
200     logger.setLevel(int(tools.config['log_level'] or '0'))
201
202 # A alternative logging scheme for automated runs of the
203 # server intended to test it.
204 def init_alternative_logger():
205     class H(logging.Handler):
206       def emit(self, record):
207         if record.levelno > 20:
208           print record.levelno, record.pathname, record.msg
209     handler = H()
210     logger = logging.getLogger()
211     logger.handlers = []
212     logger.addHandler(handler)
213     logger.setLevel(logging.ERROR)
214
215 class Agent(object):
216     """ Singleton that keeps track of cancellable tasks to run at a given
217         timestamp.
218
219         The tasks are characterised by:
220
221             * a timestamp
222             * the database on which the task run
223             * the function to call
224             * the arguments and keyword arguments to pass to the function
225
226         Implementation details:
227
228           - Tasks are stored as list, allowing the cancellation by setting
229             the timestamp to 0.
230           - A heapq is used to store tasks, so we don't need to sort
231             tasks ourself.
232     """
233     __tasks = []
234     __tasks_by_db = {}
235     _logger = logging.getLogger('netsvc.agent')
236
237     @classmethod
238     def setAlarm(cls, function, timestamp, db_name, *args, **kwargs):
239         task = [timestamp, db_name, function, args, kwargs]
240         heapq.heappush(cls.__tasks, task)
241         cls.__tasks_by_db.setdefault(db_name, []).append(task)
242
243     @classmethod
244     def cancel(cls, db_name):
245         """Cancel all tasks for a given database. If None is passed, all tasks are cancelled"""
246         cls._logger.debug("Cancel timers for %s db", db_name or 'all')
247         if db_name is None:
248             cls.__tasks, cls.__tasks_by_db = [], {}
249         else:
250             if db_name in cls.__tasks_by_db:
251                 for task in cls.__tasks_by_db[db_name]:
252                     task[0] = 0
253
254     @classmethod
255     def quit(cls):
256         cls.cancel(None)
257
258     @classmethod
259     def runner(cls):
260         """Neverending function (intended to be ran in a dedicated thread) that
261            checks every 60 seconds tasks to run. TODO: make configurable
262         """
263         current_thread = threading.currentThread()
264         while True:
265             while cls.__tasks and cls.__tasks[0][0] < time.time():
266                 task = heapq.heappop(cls.__tasks)
267                 timestamp, dbname, function, args, kwargs = task
268                 cls.__tasks_by_db[dbname].remove(task)
269                 if not timestamp:
270                     # null timestamp -> cancelled task
271                     continue
272                 current_thread.dbname = dbname   # hack hack
273                 cls._logger.debug("Run %s.%s(*%s, **%s)", function.im_class.__name__, function.func_name, args, kwargs)
274                 delattr(current_thread, 'dbname')
275                 task_thread = threading.Thread(target=function, name='netsvc.Agent.task', args=args, kwargs=kwargs)
276                 # force non-daemon task threads (the runner thread must be daemon, and this property is inherited by default)
277                 task_thread.setDaemon(False)
278                 task_thread.start()
279                 time.sleep(1)
280             time.sleep(60)
281
282 def start_agent():
283     agent_runner = threading.Thread(target=Agent.runner, name="netsvc.Agent.runner")
284     # the agent runner is a typical daemon thread, that will never quit and must be
285     # terminated when the main process exits - with no consequence (the processing
286     # threads it spawns are not marked daemon)
287     agent_runner.setDaemon(True)
288     agent_runner.start()
289
290 import traceback
291
292 class Server:
293     """ Generic interface for all servers with an event loop etc.
294         Override this to impement http, net-rpc etc. servers.
295
296         Servers here must have threaded behaviour. start() must not block,
297         there is no run().
298     """
299     __is_started = False
300     __servers = []
301     __starter_threads = []
302
303     # we don't want blocking server calls (think select()) to
304     # wait forever and possibly prevent exiting the process,
305     # but instead we want a form of polling/busy_wait pattern, where
306     # _server_timeout should be used as the default timeout for
307     # all I/O blocking operations
308     _busywait_timeout = 0.5
309
310
311     __logger = logging.getLogger('server')
312
313     def __init__(self):
314         Server.__servers.append(self)
315         if Server.__is_started:
316             # raise Exception('All instances of servers must be inited before the startAll()')
317             # Since the startAll() won't be called again, allow this server to
318             # init and then start it after 1sec (hopefully). Register that
319             # timer thread in a list, so that we can abort the start if quitAll
320             # is called in the meantime
321             t = threading.Timer(1.0, self._late_start)
322             t.name = 'Late start timer for %s' % str(self.__class__)
323             Server.__starter_threads.append(t)
324             t.start()
325
326     def start(self):
327         self.__logger.debug("called stub Server.start")
328
329     def _late_start(self):
330         self.start()
331         for thr in Server.__starter_threads:
332             if thr.finished.is_set():
333                 Server.__starter_threads.remove(thr)
334
335     def stop(self):
336         self.__logger.debug("called stub Server.stop")
337
338     def stats(self):
339         """ This function should return statistics about the server """
340         return "%s: No statistics" % str(self.__class__)
341
342     @classmethod
343     def startAll(cls):
344         if cls.__is_started:
345             return
346         cls.__logger.info("Starting %d services" % len(cls.__servers))
347         for srv in cls.__servers:
348             srv.start()
349         cls.__is_started = True
350
351     @classmethod
352     def quitAll(cls):
353         if not cls.__is_started:
354             return
355         cls.__logger.info("Stopping %d services" % len(cls.__servers))
356         for thr in cls.__starter_threads:
357             if not thr.finished.is_set():
358                 thr.cancel()
359             cls.__starter_threads.remove(thr)
360
361         for srv in cls.__servers:
362             srv.stop()
363         cls.__is_started = False
364
365     @classmethod
366     def allStats(cls):
367         res = ["Servers %s" % ('stopped', 'started')[cls.__is_started]]
368         res.extend(srv.stats() for srv in cls.__servers)
369         return '\n'.join(res)
370
371     def _close_socket(self):
372         close_socket(self.socket)
373
374 class OpenERPDispatcherException(Exception):
375     def __init__(self, exception, traceback):
376         self.exception = exception
377         self.traceback = traceback
378
379 def replace_request_password(args):
380     # password is always 3rd argument in a request, we replace it in RPC logs
381     # so it's easier to forward logs for diagnostics/debugging purposes...
382     args = list(args)
383     if len(args) > 2:
384         args[2] = '*'
385     return args
386
387 def log(title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
388     logger = logging.getLogger(title)
389     if logger.isEnabledFor(channel):
390         indent=''
391         indent_after=' '*len(fn)
392         for line in (fn+pformat(msg, depth=depth)).split('\n'):
393             logger.log(channel, indent+line)
394             indent=indent_after
395
396 # This class is used to dispatch a RPC to a service. So it is used
397 # for both XMLRPC (with a SimpleXMLRPCRequestHandler), and NETRPC.
398 # The service (ExportService) will then dispatch on the method name.
399 # This can be re-written as a single function
400 #   def dispatch(self, service_name, method, params, auth_provider).
401 class OpenERPDispatcher:
402     def log(self, title, msg, channel=logging.DEBUG_RPC, depth=None, fn=""):
403         log(title, msg, channel=channel, depth=depth, fn=fn)
404     def dispatch(self, service_name, method, params):
405         try:
406             auth = getattr(self, 'auth_provider', None)
407             logger = logging.getLogger('result')
408             start_time = end_time = 0
409             if logger.isEnabledFor(logging.DEBUG_RPC_ANSWER):
410                 self.log('service', tuple(replace_request_password(params)), depth=None, fn='%s.%s'%(service_name,method))
411             if logger.isEnabledFor(logging.DEBUG_RPC):
412                 start_time = time.time()
413             result = ExportService.getService(service_name).dispatch(method, auth, params)
414             if logger.isEnabledFor(logging.DEBUG_RPC):
415                 end_time = time.time()
416             if not logger.isEnabledFor(logging.DEBUG_RPC_ANSWER):
417                 self.log('service (%.3fs)' % (end_time - start_time), tuple(replace_request_password(params)), depth=1, fn='%s.%s'%(service_name,method))
418             self.log('execution time', '%.3fs' % (end_time - start_time), channel=logging.DEBUG_RPC_ANSWER)
419             self.log('result', result, channel=logging.DEBUG_RPC_ANSWER)
420             return result
421         except Exception, e:
422             self.log('exception', tools.exception_to_unicode(e))
423             tb = getattr(e, 'traceback', sys.exc_info())
424             tb_s = "".join(traceback.format_exception(*tb))
425             if tools.config['debug_mode'] and isinstance(tb[2], types.TracebackType):
426                 import pdb
427                 pdb.post_mortem(tb[2])
428             raise OpenERPDispatcherException(e, tb_s)
429
430 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: