e52b94d89fa6bf47ac8ae49198813f1dc877324a
[odoo/odoo.git] / bin / sql_db.py
1 # -*- encoding: utf-8 -*-
2 ##############################################################################
3 #
4 #    OpenERP, Open Source Management Solution   
5 #    Copyright (C) 2004-2008 Tiny SPRL (<http://tiny.be>). All Rights Reserved
6 #    $Id$
7 #
8 #    This program is free software: you can redistribute it and/or modify
9 #    it under the terms of the GNU General Public License as published by
10 #    the Free Software Foundation, either version 3 of the License, or
11 #    (at your option) any later version.
12 #
13 #    This program is distributed in the hope that it will be useful,
14 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
15 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 #    GNU General Public License for more details.
17 #
18 #    You should have received a copy of the GNU General Public License
19 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 #
21 ##############################################################################
22
23 import netsvc
24 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_SERIALIZABLE
25 from psycopg2.pool import ThreadedConnectionPool
26 from psycopg2.psycopg1 import cursor as psycopg1cursor
27
28 import psycopg2.extensions
29 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
30
31 import tools
32 import re
33
34 from mx import DateTime as mdt
35 re_from = re.compile('.* from "?([a-zA-Z_0-9]+)"? .*$');
36 re_into = re.compile('.* into "?([a-zA-Z_0-9]+)"? .*$');
37
38 def log(msg, lvl=netsvc.LOG_DEBUG):
39     logger = netsvc.Logger()
40     logger.notifyChannel('sql', lvl, msg)
41
42 class Cursor(object):
43     IN_MAX = 1000
44     sql_from_log = {}
45     sql_into_log = {}
46     sql_log = False
47     count = 0
48
49     def __init__(self, pool):
50         self._pool = pool
51         self._cnx = pool.getconn()
52         self.autocommit(False)
53         self._obj = self._cnx.cursor(cursor_factory=psycopg1cursor)
54         self.dbname = pool.dbname
55
56     def execute(self, query, params=None):
57         if not params:
58             params=()
59         def base_string(s):
60             if isinstance(s, unicode):
61                 return s.encode('utf-8')
62             return s
63         p=map(base_string, params)
64         query = base_string(query)
65
66         if '%d' in query or '%f' in query:
67             log(queyr, netsvc.LOG_WARNING)
68             log("SQL queries mustn't containt %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
69             query = query.replace('%d', '%s').replace('%f', '%s')
70
71         if self.sql_log:
72             now = mdt.now()
73             log("SQL LOG query: %s" % (query,))
74             log("SQL LOG params: %r" % (p,))
75
76         res = self._obj.execute(query, p)
77
78         if self.sql_log:
79             self.count+=1
80             res_from = re_from.match(query.lower())
81             if res_from:
82                 self.sql_from_log.setdefault(res_from.group(1), [0, 0])
83                 self.sql_from_log[res_from.group(1)][0] += 1
84                 self.sql_from_log[res_from.group(1)][1] += mdt.now() - now
85             res_into = re_into.match(query.lower())
86             if res_into:
87                 self.sql_into_log.setdefault(res_into.group(1), [0, 0])
88                 self.sql_into_log[res_into.group(1)][0] += 1
89                 self.sql_into_log[res_into.group(1)][1] += mdt.now() - now
90         return res
91
92     def print_log(self, type='from'):
93         log("SQL LOG %s:" % (type,))
94         if type == 'from':
95             logs = self.sql_from_log.items()
96         else:
97             logs = self.sql_into_log.items()
98         logs.sort(lambda x, y: cmp(x[1][1], y[1][1]))
99         sum=0
100         for r in logs:
101             log("table: %s: %s/%s" %(r[0], str(r[1][1]), r[1][0]))
102             sum+= r[1][1]
103         log("SUM:%s/%d" % (sum, self.count))
104
105     def close(self):
106         if self.sql_log:
107             self.print_log('from')
108             self.print_log('into')
109         self._obj.close()
110
111         # This force the cursor to be freed, and thus, available again. It is
112         # important because otherwise we can overload the server very easily
113         # because of a cursor shortage (because cursors are not garbage
114         # collected as fast as they should). The problem is probably due in
115         # part because browse records keep a reference to the cursor.
116         del self._obj
117         self._pool.putconn(self._cnx)
118     
119     def autocommit(self, on):
120         self._cnx.set_isolation_level([ISOLATION_LEVEL_SERIALIZABLE, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
121     
122     def commit(self):
123         return self._cnx.commit()
124     
125     def rollback(self):
126         return self._cnx.rollback()
127
128     def __getattr__(self, name):
129         return getattr(self._obj, name)
130
131 class ConnectionPool(object):
132     def __init__(self, pool, dbname):
133         self.dbname = dbname
134         self._pool = pool
135
136     def cursor(self):
137         return Cursor(self)
138
139     def __getattr__(self, name):
140         return getattr(self._pool, name)
141
142 class PoolManager(object):
143     _pools = {}
144     _dsn = None
145     maxconn =  int(tools.config['db_maxconn']) or 64
146     
147     def dsn(db_name):
148         if PoolManager._dsn is None:
149             PoolManager._dsn = ''
150             for p in ('host', 'port', 'user', 'password'):
151                 cfg = tools.config['db_' + p]
152                 if cfg:
153                     PoolManager._dsn += '%s=%s ' % (p, cfg)
154         return '%s dbname=%s' % (PoolManager._dsn, db_name)
155     dsn = staticmethod(dsn)
156
157     def get(db_name):
158         if db_name not in PoolManager._pools:
159             logger = netsvc.Logger()
160             try:
161                 logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Connecting to %s' % (db_name.lower()))
162                 PoolManager._pools[db_name] = ConnectionPool(ThreadedConnectionPool(0, PoolManager.maxconn, PoolManager.dsn(db_name)), db_name)
163             except Exception, e:
164                 logger.notifyChannel('dbpool', netsvc.LOG_CRITICAL, 'Unable to connect to %s: %r' % (db_name, e))
165                 raise
166         return PoolManager._pools[db_name]
167     get = staticmethod(get)
168
169 def db_connect(db_name, serialize=0):
170     return PoolManager.get(db_name)
171
172 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:
173