[FIX] unlink take the context in parameter
[odoo/odoo.git] / bin / sql_db.py
1 # -*- encoding: utf-8 -*-
2 ##############################################################################
3 #
4 #    OpenERP, Open Source Management Solution   
5 #    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>). All Rights Reserved
6 #    $Id$
7 #
8 #    This program is free software: you can redistribute it and/or modify
9 #    it under the terms of the GNU General Public License as published by
10 #    the Free Software Foundation, either version 3 of the License, or
11 #    (at your option) any later version.
12 #
13 #    This program is distributed in the hope that it will be useful,
14 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
15 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 #    GNU General Public License for more details.
17 #
18 #    You should have received a copy of the GNU General Public License
19 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 #
21 ##############################################################################
22
23 import netsvc
24 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_SERIALIZABLE
25 from psycopg2.pool import ThreadedConnectionPool
26 from psycopg2.psycopg1 import cursor as psycopg1cursor
27
28 import psycopg2.extensions
29
30 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
31
32 types_mapping = {
33     'date': (1082,),
34     'time': (1083,),
35     'datetime': (1114,),
36 }
37
38 def unbuffer(symb, cr):
39     if symb is None: return None
40     return str(symb)
41
42 def undecimalize(symb, cr):
43     if symb is None: return None
44     return float(symb)
45
46 for name, typeoid in types_mapping.items():
47     psycopg2.extensions.register_type(psycopg2.extensions.new_type(typeoid, name, lambda x, cr: x))
48 psycopg2.extensions.register_type(psycopg2.extensions.new_type((700, 701, 1700,), 'float', undecimalize))
49
50
51 import tools
52 import re
53
54 from mx import DateTime as mdt
55 re_from = re.compile('.* from "?([a-zA-Z_0-9]+)"? .*$');
56 re_into = re.compile('.* into "?([a-zA-Z_0-9]+)"? .*$');
57
58 def log(msg, lvl=netsvc.LOG_DEBUG):
59     logger = netsvc.Logger()
60     logger.notifyChannel('sql', lvl, msg)
61
62 class Cursor(object):
63     IN_MAX = 1000
64     sql_from_log = {}
65     sql_into_log = {}
66     sql_log = False
67     count = 0
68     
69     def check(f):
70         from tools.func import wraps
71
72         @wraps(f)
73         def wrapper(self, *args, **kwargs):
74             if not hasattr(self, '_obj'):
75                 raise psycopg2.ProgrammingError('Unable to use the cursor after having closing it')
76             return f(self, *args, **kwargs)
77         return wrapper
78
79     def __init__(self, pool):
80         self._pool = pool
81         self._cnx = pool.getconn()
82         self._obj = self._cnx.cursor(cursor_factory=psycopg1cursor)
83         self.autocommit(False)
84         self.dbname = pool.dbname
85
86         if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
87             from inspect import stack
88             self.__caller = tuple(stack()[2][1:3])
89         
90     def __del__(self):
91         if hasattr(self, '_obj'):
92             if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
93                 # Oops. 'self' has not been closed explicitly.
94                 # The cursor will be deleted by the garbage collector, 
95                 # but the database connection is not put back into the connection
96                 # pool, preventing some operation on the database like dropping it.
97                 # This can also lead to a server overload.
98                 msg = "Cursor not closed explicitly\n"  \
99                       "Cursor was created at %s:%s" % self.__caller
100
101                 log(msg, netsvc.LOG_WARNING)
102             self.close()
103
104     @check
105     def execute(self, query, params=None):
106         self.count+=1
107         if '%d' in query or '%f' in query:
108             log(query, netsvc.LOG_WARNING)
109             log("SQL queries mustn't containt %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
110             if params:
111                 query = query.replace('%d', '%s').replace('%f', '%s')
112
113         if self.sql_log:
114             now = mdt.now()
115         
116         res = self._obj.execute(query, params)
117
118         if self.sql_log:
119             log("query: %s" % self._obj.query)
120             self.count+=1
121             res_from = re_from.match(query.lower())
122             if res_from:
123                 self.sql_from_log.setdefault(res_from.group(1), [0, 0])
124                 self.sql_from_log[res_from.group(1)][0] += 1
125                 self.sql_from_log[res_from.group(1)][1] += mdt.now() - now
126             res_into = re_into.match(query.lower())
127             if res_into:
128                 self.sql_into_log.setdefault(res_into.group(1), [0, 0])
129                 self.sql_into_log[res_into.group(1)][0] += 1
130                 self.sql_into_log[res_into.group(1)][1] += mdt.now() - now
131         return res
132
133     def print_log(self):
134         def process(type):
135             sqllogs = {'from':self.sql_from_log, 'into':self.sql_into_log}
136             if not sqllogs[type]:
137                 return
138             sqllogitems = sqllogs[type].items()
139             sqllogitems.sort(key=lambda k: k[1][1])
140             sum = 0
141             log("SQL LOG %s:" % (type,))
142             for r in sqllogitems:
143                 log("table: %s: %s/%s" %(r[0], str(r[1][1]), r[1][0]))
144                 sum+= r[1][1]
145             log("SUM:%s/%d" % (sum, self.count))
146             sqllogs[type].clear()
147         process('from')
148         process('into')
149         self.count = 0
150         self.sql_log = False
151
152     @check
153     def close(self):
154         self.print_log()
155         self._obj.close()
156
157         # This force the cursor to be freed, and thus, available again. It is
158         # important because otherwise we can overload the server very easily
159         # because of a cursor shortage (because cursors are not garbage
160         # collected as fast as they should). The problem is probably due in
161         # part because browse records keep a reference to the cursor.
162         del self._obj
163         self._pool.putconn(self._cnx)
164     
165     @check
166     def autocommit(self, on):
167         self._cnx.set_isolation_level([ISOLATION_LEVEL_SERIALIZABLE, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
168     
169     @check
170     def commit(self):
171         return self._cnx.commit()
172     
173     @check
174     def rollback(self):
175         return self._cnx.rollback()
176
177     @check
178     def __getattr__(self, name):
179         return getattr(self._obj, name)
180
181 class ConnectionPool(object):
182     def __init__(self, pool, dbname):
183         self.dbname = dbname
184         self._pool = pool
185
186     def cursor(self):
187         return Cursor(self)
188
189     def __getattr__(self, name):
190         return getattr(self._pool, name)
191
192 class PoolManager(object):
193     _pools = {}
194     _dsn = None
195     maxconn =  int(tools.config['db_maxconn']) or 64
196     
197     @classmethod 
198     def dsn(cls, db_name):
199         if cls._dsn is None:
200             cls._dsn = ''
201             for p in ('host', 'port', 'user', 'password'):
202                 cfg = tools.config['db_' + p]
203                 if cfg:
204                     cls._dsn += '%s=%s ' % (p, cfg)
205         return '%s dbname=%s' % (cls._dsn, db_name)
206
207     @classmethod
208     def get(cls, db_name):
209         if db_name not in cls._pools:
210             logger = netsvc.Logger()
211             try:
212                 logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Connecting to %s' % (db_name,))
213                 cls._pools[db_name] = ConnectionPool(ThreadedConnectionPool(1, cls.maxconn, cls.dsn(db_name)), db_name)
214             except Exception, e:
215                 logger.notifyChannel('dbpool', netsvc.LOG_ERROR, 'Unable to connect to %s: %s' %
216                                      (db_name, str(e)))
217                 raise
218         return cls._pools[db_name]
219
220     @classmethod
221     def close(cls, db_name):
222         if db_name in cls._pools:
223             logger = netsvc.Logger()
224             logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Closing all connections to %s' % (db_name,))
225             cls._pools[db_name].closeall()
226             del cls._pools[db_name]
227
228 def db_connect(db_name):
229     return PoolManager.get(db_name)
230
231 def close_db(db_name):
232     PoolManager.close(db_name)
233     tools.cache.clean_caches_for_db(db_name)
234     
235
236 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:
237