[FIX] XML-RPC SSL
[odoo/odoo.git] / bin / netsvc.py
1 #!/usr/bin/python
2 # -*- encoding: utf-8 -*-
3 ##############################################################################
4 #
5 #    OpenERP, Open Source Management Solution
6 #    Copyright (C) 2004-2008 Tiny SPRL (<http://tiny.be>). All Rights Reserved
7 #    $Id$
8 #
9 #    This program is free software: you can redistribute it and/or modify
10 #    it under the terms of the GNU General Public License as published by
11 #    the Free Software Foundation, either version 3 of the License, or
12 #    (at your option) any later version.
13 #
14 #    This program is distributed in the hope that it will be useful,
15 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
16 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 #    GNU General Public License for more details.
18 #
19 #    You should have received a copy of the GNU General Public License
20 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22 ##############################################################################
23
24 import SimpleXMLRPCServer
25 import SocketServer
26 import logging
27 import logging.handlers
28 import os
29 import signal
30 import socket
31 import sys
32 import threading
33 import time
34 import xmlrpclib
35
# Process-wide registries shared by every service defined in this module.
_service = dict()   # service name -> Service instance (filled by Service.__init__)
_group = dict()     # group name -> {service name: Service} (filled by Service.joinGroup)
_res_id = 1         # not read or written by any code visible in this file
_res = dict()       # endpoint id -> last result stored by ServiceEndPointCall
40
class ServiceEndPointCall(object):
    """Callable that runs one service method and parks its result.

    The wrapped method's return value is stored in the module-level ``_res``
    dict under this call's endpoint id; the caller gets the id back.
    """

    def __init__(self, id, method):
        self._meth = method
        self._id = id

    def __call__(self, *args):
        outcome = self._meth(*args)
        key = self._id
        _res[key] = outcome
        return key
49
50
class ServiceEndPoint(object):
    """Snapshot of a registered service's exported methods, bound to an id.

    Attribute access returns a ServiceEndPointCall for the named method, so
    ``endpoint.meth(*args)`` stores the result in ``_res[id]`` and returns
    the id.
    """

    def __init__(self, name, id):
        self._id = id
        # Copy the method table so methods exported on the service later do
        # not show up on this endpoint.
        self._meth = dict(_service[name]._method)

    def __getattr__(self, name):
        # __getattr__ only runs for names not found normally. Read the table
        # through __dict__ so a half-initialised instance (e.g. during copy or
        # unpickling) cannot recurse into __getattr__ forever.
        methods = self.__dict__.get('_meth', {})
        try:
            method = methods[name]
        except KeyError:
            # Attribute protocol: unknown names must raise AttributeError so
            # hasattr() and getattr(obj, name, default) behave as expected.
            raise AttributeError(name)
        return ServiceEndPointCall(self._id, method)
61
62
class Service(object):
    """Base class for a named service registered in the module registry.

    Instantiating a Service stores it in ``_service`` under *name*; callables
    are published through ``exportMethod``.
    """

    # Class-wide counter handing out endpoint ids (cycles through 1..2**16).
    _serviceEndPointID = 0

    def __init__(self, name, audience=''):
        # ``audience`` is accepted but not used.
        _service[name] = self
        self.__name = name
        self._method = {}
        self.exportedMethods = self._response_process = \
            self._response_process_id = self._response = None

    def joinGroup(self, name):
        """Add this service to group *name*, creating the group when missing."""
        _group.setdefault(name, {})[self.__name] = self

    def exportMethod(self, m):
        """Publish callable *m* under its ``__name__``; non-callables are ignored."""
        if not callable(m):
            return
        self._method[m.__name__] = m

    def serviceEndPoint(self, s):
        """Return a ServiceEndPoint for service *s* with the next endpoint id."""
        counter = Service._serviceEndPointID
        if counter >= 2 ** 16:
            counter = 0
        Service._serviceEndPointID = counter + 1
        return ServiceEndPoint(s, self._serviceEndPointID)

    def conversationId(self):
        """Always 1 in this implementation."""
        return 1

    def processResponse(self, s, id):
        """Remember callback *s* and *id* for the next suspendResponse()."""
        self._response_process = s
        self._response_process_id = id

    def processFailure(self, s, id):
        """No-op hook."""

    def resumeResponse(self, s):
        """No-op hook."""

    def cancelResponse(self, s):
        """No-op hook."""

    def suspendResponse(self, s):
        """Fire the pending callback (if any), then store ``s(id)`` as the response.

        The callback registered by processResponse() receives the id and the
        result parked in ``_res`` for that id; it is cleared afterwards.
        """
        callback = self._response_process
        if callback:
            pending_id = self._response_process_id
            callback(pending_id, _res[pending_id])
        self._response_process = None
        self._response = s(self._response_process_id)

    def abortResponse(self, error, description, origin, details):
        """Raise an Exception built from origin/description/details.

        In debug mode the exception currently being handled is re-raised
        instead.
        """
        import tools
        if tools.config['debug_mode']:
            raise
        raise Exception("%s -- %s\n\n%s"%(origin, description, details))

    def currentFailure(self, s):
        """No-op hook."""
121
122
class LocalService(Service):
    """In-process proxy exposing a registered service's methods as attributes.

    Raises KeyError (after logging it on the 'module' channel) when *name*
    is not registered.
    """

    def __init__(self, name):
        # Service.__init__ is not called: it would overwrite _service[name]
        # with this proxy.
        self.__name = name
        try:
            target = _service[name]
        except KeyError:
            missing = sys.exc_info()[1]
            Logger().notifyChannel('module', LOG_ERROR, 'This service does not exists: %s' % (str(missing),) )
            raise
        self._service = target
        for method_name in target._method:
            setattr(self, method_name, target._method[method_name])
134
def service_exist(name):
    """Return True when *name* is registered and its service object is truthy."""
    return bool(_service.get(name))
137
# Level names accepted by Logger.notifyChannel, which looks them up as
# methods of a logging.Logger ('debug_rpc' is attached there on demand).
LOG_DEBUG_RPC = 'debug_rpc'
LOG_DEBUG = 'debug'
LOG_INFO = 'info'
LOG_WARNING = 'warn'
LOG_ERROR = 'error'
LOG_CRITICAL = 'critical'

# add new log level below DEBUG, and register its name so records at this
# level print as "DEBUG_RPC" rather than "Level 9" wherever no coloured name
# is installed (log files, Windows consoles).
logging.DEBUG_RPC = logging.DEBUG - 1
logging.addLevelName(logging.DEBUG_RPC, 'DEBUG_RPC')
147
def init_logger():
    """Attach one handler to the root logger, configured from tools.config.

    Logs go to ``config['logfile']`` (rotated daily, 30 backups kept) when it
    is set and the handler can be created; otherwise to stdout. The root
    level is set from ``config['log_level']``. ANSI-coloured level names are
    installed only for console output on non-Windows platforms.
    """
    from tools import config

    handler = None
    logf = config['logfile']
    if logf:
        try:
            dirname = os.path.dirname(logf)
            if dirname and not os.path.isdir(dirname):
                os.makedirs(dirname)
            handler = logging.handlers.TimedRotatingFileHandler(logf,'D',1,30)
        except Exception:
            sys.stderr.write("ERROR: couldn't create the logfile directory\n")
    if handler is None:
        handler = logging.StreamHandler(sys.stdout)

    # create a format for log messages and dates
    formatter = logging.Formatter('[%(asctime)s] %(levelname)s:%(name)s:%(message)s', '%a %b %d %H:%M:%S %Y')
    handler.setFormatter(formatter)

    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(config['log_level'])

    # FileHandler (and therefore TimedRotatingFileHandler) subclasses
    # StreamHandler, so exclude it explicitly: escape codes must not end up
    # in the log file.
    if (isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            and os.name != 'nt'):
        _colorize_level_names()


def _colorize_level_names():
    """Install ANSI-coloured display names for the standard levels and DEBUG_RPC."""
    # see http://pueblo.sourceforge.net/doc/manual/ansi_color_codes.html
    # maybe use http://code.activestate.com/recipes/574451/
    # The None placeholder keeps 'default' at index 9 (SGR codes 39 / 49).
    colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white', None, 'default']
    foreground = lambda f: 30 + colors.index(f)
    background = lambda f: 40 + colors.index(f)

    mapping = {
        'DEBUG_RPC': ('blue', 'white'),
        'DEBUG': ('blue', 'default'),
        'INFO': ('green', 'default'),
        'WARNING': ('yellow', 'default'),
        'ERROR': ('red', 'default'),
        'CRITICAL': ('white', 'red'),
    }

    for level, (fg, bg) in mapping.items():
        msg = "\x1b[%dm\x1b[%dm%s\x1b[0m" % (foreground(fg), background(bg), level)
        logging.addLevelName(getattr(logging, level), msg)
197
198
class Logger(object):
    """Thin facade over the ``logging`` module used throughout the server."""

    def notifyChannel(self, name, level, msg):
        """Log *msg* on the logger called *name* at *level*.

        :param name: logger (channel) name passed to logging.getLogger().
        :param level: one of the LOG_* strings; it names a method of the
            logger ('debug', 'info', 'warn', ...). For 'debug_rpc' such a
            method is attached to the logger on first use.
        :param msg: any object. A multi-line message is logged as one record
            per line, each prefixed with a two-digit line number.
        """
        log = logging.getLogger(name)

        if level == LOG_DEBUG_RPC and not hasattr(log, level):
            fct = lambda msg, *args, **kwargs: log.log(logging.DEBUG_RPC, msg, *args, **kwargs)
            setattr(log, LOG_DEBUG_RPC, fct)

        level_method = getattr(log, level)

        # '%s' % (msg,) rather than str(msg): under Python 2, str() of a
        # unicode message containing non-ASCII characters raises
        # UnicodeEncodeError, while %-formatting keeps the text as unicode.
        lines = ('%s' % (msg,)).strip().split('\n')
        if len(lines) > 1:
            for idx, line in enumerate(lines):
                level_method('[%02d]: %s' % (idx + 1, line))
        else:
            # split() always yields at least one element.
            level_method(lines[0])
215
# Configure the root logger as soon as this module is imported.
init_logger()
217
class Agent(object):
    """Schedules deferred jobs on threading.Timer threads.

    Timers live in a class-level list shared by all instances, so
    ``Agent.quit()`` cancels every job scheduled through any Agent.
    """

    _timers = []
    _logger = Logger()

    def setAlarm(self, fn, dt, args=None, kwargs=None):
        """Run ``fn(*args, **kwargs)`` at absolute time *dt* (time.time() seconds).

        Nothing is scheduled when *dt* is not in the future. Finished timers
        are pruned from the tracking list on every call.
        """
        if not args:
            args = []
        if not kwargs:
            kwargs = {}
        wait = dt - time.time()
        if wait > 0:
            self._logger.notifyChannel('timers', LOG_DEBUG, "Job scheduled in %s seconds for %s" % (wait, self._describe_job(fn)))
            timer = threading.Timer(wait, fn, args, kwargs)
            timer.start()
            self._timers.append(timer)
        # Iterate over a copy because the list is modified inside the loop.
        for timer in self._timers[:]:
            if not timer.isAlive():
                self._timers.remove(timer)

    def _describe_job(fn):
        """Return 'Class.name' for a method, or just the name for other callables.

        The previous code read fn.im_class / fn.func_name directly, which
        raised AttributeError for plain functions and other callables.
        """
        owner = getattr(fn, 'im_class', None)
        if owner is None:
            bound_to = getattr(fn, '__self__', None)
            if bound_to is not None:
                owner = bound_to.__class__
        name = getattr(fn, '__name__', None) or repr(fn)
        if owner is None:
            return name
        return '%s.%s' % (owner.__name__, name)
    _describe_job = staticmethod(_describe_job)

    def quit(cls):
        """Cancel every timer scheduled through any Agent."""
        for timer in cls._timers:
            timer.cancel()
    quit = classmethod(quit)
241
class xmlrpc(object):
    """Namespace holding the XML-RPC gateway descriptor class."""

    class RpcGateway(object):
        """Records a gateway name; it has no other behavior here."""

        def __init__(self, name):
            setattr(self, 'name', name)
246
class GenericXMLRPCRequestHandler:
    """Mixin that dispatches XML-RPC calls to the LocalService named by the URL.

    The last component of the request path selects the service; the XML-RPC
    method name is looked up on it and called with the request params.
    """

    def log(self, title, msg):
        """Log a pretty-printed *msg* at DEBUG_RPC on channel 'XMLRPC-<title>'."""
        from pprint import pformat
        Logger().notifyChannel('XMLRPC-%s' % title, LOG_DEBUG_RPC, pformat(msg))

    def _dispatch(self, method, params):
        """Call *method* on the service and return its result.

        A value the call leaves in the service's ``_response`` takes
        precedence over the direct return value. Any exception becomes an
        ``xmlrpclib.Fault`` whose faultCode is ``str(exception)`` and whose
        faultString is the formatted traceback.
        """
        import traceback
        try:
            self.log('method', method)
            self.log('params', params)
            service_name = self.path.split("/")[-1]
            service = LocalService(service_name)
            meth = getattr(service, method)
            service._service._response = None
            result = meth(*params)
            self.log('result', result)
            res = service._service._response
            if res is not None:
                result = res
            self.log('res', result)
            return result
        except Exception:
            # sys.exc_info() is per-thread. The legacy sys.exc_type /
            # exc_value / exc_traceback are process-wide and this handler runs
            # in a ThreadingMixIn worker, so they could describe another
            # thread's exception. Capture before any other call can replace it.
            exc_type, e, tb = sys.exc_info()
            self.log('exception', e)
            tb_s = ''.join(traceback.format_exception(exc_type, e, tb))
            s = str(e)
            import tools
            if tools.config['debug_mode']:
                import pdb
                pdb.post_mortem(tb)
            raise xmlrpclib.Fault(s, tb_s)
278
279 # refactoring from Tryton (B2CK, Cedric Krier, Bertrand Chenal)
class SSLSocket(object):
    """Wrap a plain socket in a pyOpenSSL connection, or pass an SSL one through.

    Objects that already provide ``sock_shutdown`` (pyOpenSSL Connection
    objects) are used as-is. Any other object is wrapped in a server-side
    ``SSL.Connection`` using *keyfile* and *certfile*. These default to
    'server.pkey' / 'server.cert', resolved against the current working
    directory. All other attributes are delegated to the wrapped object.
    """

    def __init__(self, socket, keyfile='server.pkey', certfile='server.cert'):
        if hasattr(socket, 'sock_shutdown'):
            self.socket = socket
            return
        from OpenSSL import SSL
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        # SSLv23_METHOD negotiates the best protocol both peers support but
        # would still accept SSLv2, which is broken (RFC 6176). Refuse it when
        # this pyOpenSSL build exposes the option.
        no_sslv2 = getattr(SSL, 'OP_NO_SSLv2', None)
        if no_sslv2 is not None:
            ctx.set_options(no_sslv2)
        ctx.use_privatekey_file(keyfile)
        ctx.use_certificate_file(certfile)
        self.socket = SSL.Connection(ctx, socket)

    def shutdown(self, how):
        # pyOpenSSL's Connection.shutdown() takes no argument (it sends the
        # TLS close notify); the socket-level shutdown(how) that SocketServer
        # calls maps to sock_shutdown().
        return self.socket.sock_shutdown(how)

    def __getattr__(self, name):
        # Guard against infinite recursion when 'socket' itself is missing
        # (e.g. an instance created without __init__ during copy/unpickle).
        if name == 'socket':
            raise AttributeError(name)
        return getattr(self.socket, name)
296
class SimpleXMLRPCRequestHandler(GenericXMLRPCRequestHandler, SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    # Accepted request paths: one '/xmlrpc/<service>' entry per service that
    # is registered when this class body executes (module import time). The
    # stdlib handler treats an empty list as "accept every path".
    rpc_paths = list('/xmlrpc/%s' % service_name for service_name in _service)
299
class SecureXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
    """Request handler for the SSL server.

    ``self.connection`` is the request wrapped in SSLSocket. The buffered
    rfile/wfile objects are built directly on the request object.
    """

    def setup(self):
        raw = self.request
        self.connection = SSLSocket(raw)
        make_file = socket._fileobject
        self.rfile = make_file(raw, "rb", self.rbufsize)
        self.wfile = make_file(raw, "wb", self.wbufsize)
305
class SimpleThreadedXMLRPCServer(SocketServer.ThreadingMixIn, SimpleXMLRPCServer.SimpleXMLRPCServer):
    """XML-RPC server that handles each request in its own thread."""

    def server_bind(self):
        """Enable SO_REUSEADDR on the listening socket, then bind it.

        SO_REUSEADDR lets the server rebind its port right after a restart,
        while earlier connections may still linger in TIME_WAIT.
        """
        listener = self.socket
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        base = SimpleXMLRPCServer.SimpleXMLRPCServer
        base.server_bind(self)
310
class SecureThreadedXMLRPCServer(SimpleThreadedXMLRPCServer):
    """Threaded XML-RPC server whose listening socket is wrapped in SSLSocket."""

    def __init__(self, server_address, HandlerClass, logRequests=1):
        # The base constructor (SocketServer.TCPServer's default behavior)
        # already creates, binds and activates a plain TCP socket on
        # server_address.
        SimpleThreadedXMLRPCServer.__init__(self, server_address, HandlerClass, logRequests)
        # Release that plain socket explicitly before binding the SSL one to
        # the same address. Previously it was only freed when its last
        # reference vanished on reassignment, which relies on CPython
        # refcounting to free the port in time.
        self.socket.close()
        self.socket = SSLSocket(socket.socket(self.address_family, self.socket_type))
        self.server_bind()
        self.server_activate()
317 # end of refactoring from Tryton
318
class HttpDaemon(threading.Thread):
    """Thread serving XML-RPC over HTTP, or over SSL when *secure* is true."""

    def __init__(self, interface, port, secure=False):
        threading.Thread.__init__(self)
        self.__port = port
        self.__interface = interface
        self.secure = bool(secure)
        if self.secure:
            handler_class = SecureXMLRPCRequestHandler
            server_class = SecureThreadedXMLRPCServer
        else:
            handler_class = SimpleXMLRPCRequestHandler
            server_class = SimpleThreadedXMLRPCServer
        # Third argument is logRequests; 0 disables per-request logging.
        self.server = server_class((interface, port), handler_class, 0)
        # Defined up front so the attribute exists before run() starts.
        self.running = False

    def attach(self, path, gw):
        """Accepted for interface compatibility; does nothing."""
        pass

    def stop(self):
        """Ask the serving loop to end and close the listening socket."""
        self.running = False
        if os.name != 'nt':
            self.server.socket.shutdown(getattr(socket, 'SHUT_RDWR', 2))
        self.server.socket.close()

    def run(self):
        """Serve requests one at a time until stop() is called."""
        self.server.register_introspection_functions()

        self.running = True
        while self.running:
            try:
                self.server.handle_request()
            except Exception:
                # stop() closes the listening socket while handle_request()
                # may be blocked on it; the resulting error is expected then
                # and must not kill the thread with a traceback. Errors while
                # still running are propagated as before.
                if self.running:
                    raise
        return True
353
354 import tiny_socket
class TinySocketClientThread(threading.Thread):
    """Serve a single NET-RPC request received on an accepted client socket.

    The received message is a sequence ``(service, method, *args)``. The
    method's result, or the exception it raised, is sent back through
    tiny_socket; then the connection is closed and the thread removes itself
    from the shared ``threads`` list.
    """

    def __init__(self, sock, threads):
        threading.Thread.__init__(self)
        self.sock = sock
        self.threads = threads   # shared with TinySocketServerThread
        self._logger = Logger()

    def log(self, msg):
        self._logger.notifyChannel('NETRPC', LOG_DEBUG_RPC, msg)

    def run(self):
        """Handle one request; return False if it could not even be read.

        The previous `while self.running` loop always ended after its first
        iteration, so each connection serves exactly one request; that is
        kept. The cleanup now sits in a finally clause, so the socket is
        closed and the thread deregistered even if sending the reply (or the
        error report) raises.
        """
        self.running = True
        try:
            try:
                ts = tiny_socket.mysocket(self.sock)
            except:
                return False
            try:
                msg = ts.myreceive()
            except:
                return False
            self._serve_request(ts, msg)
            return True
        finally:
            self.sock.close()
            self.threads.remove(self)

    def _serve_request(self, ts, msg):
        """Dispatch *msg* to its LocalService and send back the outcome."""
        import traceback
        try:
            self.log(msg)
            service = LocalService(msg[0])
            method = getattr(service, msg[1])
            service._service._response = None
            result_from_method = method(*msg[2:])
            res = service._service._response
            if res is not None:
                result_from_method = res
            self.log(result_from_method)
            ts.mysend(result_from_method)
        except Exception:
            # sys.exc_info() is per-thread; the legacy sys.exc_type & co. are
            # process-wide and may belong to another client thread.
            exc_type, e, tb = sys.exc_info()
            print(repr(e))
            tb_s = ''.join(traceback.format_exception(exc_type, e, tb))
            import tools
            if tools.config['debug_mode']:
                import pdb
                pdb.post_mortem(tb)
            e = Exception(str(e))
            self.log(str(e))
            ts.mysend(e, exception=True, traceback=tb_s)
        except:
            # Raises that are not Exception subclasses are ignored, as before.
            pass

    def stop(self):
        self.running = False
413
414
class TinySocketServerThread(threading.Thread):
    """Accept NET-RPC connections, handing each to a TinySocketClientThread."""

    def __init__(self, interface, port, secure=False):
        # ``secure`` is accepted but unused: connections are plain TCP.
        threading.Thread.__init__(self)
        self.__port = port
        self.__interface = interface
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.__interface, self.__port))
        self.socket.listen(5)
        self.threads = []   # live client threads; they remove themselves

    def run(self):
        """Accept connections until stop(); return False if accepting fails."""
        try:
            self.running = True
            while self.running:
                (clientsocket, address) = self.socket.accept()
                ct = TinySocketClientThread(clientsocket, self.threads)
                self.threads.append(ct)
                ct.start()
            self.socket.close()
        except Exception:
            # e.g. accept() failing once stop() has shut the socket down.
            self.socket.close()
            return False

    def stop(self):
        """Stop accepting, signal client threads, and close the listening socket.

        Returns False if shutting down or closing the socket raised, None
        otherwise. The socket is closed even when shutdown() fails;
        previously a failing shutdown() skipped close() and leaked it.
        """
        self.running = False
        # Iterate over a copy: client threads remove themselves from this
        # list when they finish, possibly while we are looping.
        for t in list(self.threads):
            t.stop()
        failed = False
        try:
            self.socket.shutdown(getattr(socket, 'SHUT_RDWR', 2))
        except Exception:
            # shutdown() may be refused for a listening socket on some
            # platforms (e.g. ENOTCONN); still close it below.
            failed = True
        try:
            self.socket.close()
        except Exception:
            failed = True
        if failed:
            return False
452
453
454
455
456 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:
457