Remove sql injection problem
[odoo/odoo.git] / bin / sql_db.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #    
4 #    OpenERP, Open Source Management Solution
5 #    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>).
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU Affero General Public License as
9 #    published by the Free Software Foundation, either version 3 of the
10 #    License, or (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU Affero General Public License for more details.
16 #
17 #    You should have received a copy of the GNU Affero General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.     
19 #
20 ##############################################################################
21
22 import netsvc
23 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE
24 from psycopg2.pool import ThreadedConnectionPool
25 from psycopg2.psycopg1 import cursor as psycopg1cursor
26
27 import psycopg2.extensions
28
29 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
30
31 types_mapping = {
32     'date': (1082,),
33     'time': (1083,),
34     'datetime': (1114,),
35 }
36
37 def unbuffer(symb, cr):
38     if symb is None: return None
39     return str(symb)
40
41 def undecimalize(symb, cr):
42     if symb is None: return None
43     return float(symb)
44
45 for name, typeoid in types_mapping.items():
46     psycopg2.extensions.register_type(psycopg2.extensions.new_type(typeoid, name, lambda x, cr: x))
47 psycopg2.extensions.register_type(psycopg2.extensions.new_type((700, 701, 1700,), 'float', undecimalize))
48
49
50 import tools
51 import re
52
53 from mx import DateTime as mdt
54 re_from = re.compile('.* from "?([a-zA-Z_0-9]+)"? .*$');
55 re_into = re.compile('.* into "?([a-zA-Z_0-9]+)"? .*$');
56
57 def log(msg, lvl=netsvc.LOG_DEBUG):
58     logger = netsvc.Logger()
59     logger.notifyChannel('sql', lvl, msg)
60
61 class Cursor(object):
62     IN_MAX = 1000
63     sql_from_log = {}
64     sql_into_log = {}
65     sql_log = False
66     count = 0
67     
68     def check(f):
69         from tools.func import wraps
70
71         @wraps(f)
72         def wrapper(self, *args, **kwargs):
73             if self.__closed:
74                 raise psycopg2.ProgrammingError('Unable to use the cursor after having closing it')
75             return f(self, *args, **kwargs)
76         return wrapper
77
78     def __init__(self, pool, serialized=False):
79         self._pool = pool
80         self._serialized = serialized
81         self._cnx = pool.getconn()
82         self._obj = self._cnx.cursor(cursor_factory=psycopg1cursor)
83         self.__closed = False
84         self.autocommit(False)
85         self.dbname = pool.dbname
86
87         if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
88             from inspect import stack
89             self.__caller = tuple(stack()[2][1:3])
90         
91     def __del__(self):
92         if not self.__closed:
93             if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
94                 # Oops. 'self' has not been closed explicitly.
95                 # The cursor will be deleted by the garbage collector, 
96                 # but the database connection is not put back into the connection
97                 # pool, preventing some operation on the database like dropping it.
98                 # This can also lead to a server overload.
99                 msg = "Cursor not closed explicitly\n"  \
100                       "Cursor was created at %s:%s" % self.__caller
101
102                 log(msg, netsvc.LOG_WARNING)
103             self.close()
104
105     @check
106     def execute(self, query, params=None):
107         self.count+=1
108         if '%d' in query or '%f' in query:
109             log(query, netsvc.LOG_WARNING)
110             log("SQL queries mustn't contain %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
111             if params:
112                 query = query.replace('%d', '%s').replace('%f', '%s')
113
114         if self.sql_log:
115             now = mdt.now()
116         
117         try:
118             params = params or None
119             res = self._obj.execute(query, params)
120         except Exception, e:
121             log("bad query: %s" % self._obj.query)
122             log(e)
123             raise
124
125         if self.sql_log:
126             log("query: %s" % self._obj.query)
127             self.count+=1
128             res_from = re_from.match(query.lower())
129             if res_from:
130                 self.sql_from_log.setdefault(res_from.group(1), [0, 0])
131                 self.sql_from_log[res_from.group(1)][0] += 1
132                 self.sql_from_log[res_from.group(1)][1] += mdt.now() - now
133             res_into = re_into.match(query.lower())
134             if res_into:
135                 self.sql_into_log.setdefault(res_into.group(1), [0, 0])
136                 self.sql_into_log[res_into.group(1)][0] += 1
137                 self.sql_into_log[res_into.group(1)][1] += mdt.now() - now
138         return res
139
140     def print_log(self):
141         def process(type):
142             sqllogs = {'from':self.sql_from_log, 'into':self.sql_into_log}
143             if not sqllogs[type]:
144                 return
145             sqllogitems = sqllogs[type].items()
146             sqllogitems.sort(key=lambda k: k[1][1])
147             sum = 0
148             log("SQL LOG %s:" % (type,))
149             for r in sqllogitems:
150                 log("table: %s: %s/%s" %(r[0], str(r[1][1]), r[1][0]))
151                 sum+= r[1][1]
152             log("SUM:%s/%d" % (sum, self.count))
153             sqllogs[type].clear()
154         process('from')
155         process('into')
156         self.count = 0
157         self.sql_log = False
158
159     @check
160     def close(self):
161         self.rollback() # Ensure we close the current transaction.
162         self.print_log()
163         self._obj.close()
164
165         # This force the cursor to be freed, and thus, available again. It is
166         # important because otherwise we can overload the server very easily
167         # because of a cursor shortage (because cursors are not garbage
168         # collected as fast as they should). The problem is probably due in
169         # part because browse records keep a reference to the cursor.
170         del self._obj
171         self.__closed = True
172         self._pool.putconn(self._cnx)
173     
174     @check
175     def autocommit(self, on):
176         offlevel = [ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE][bool(self._serialized)]
177         self._cnx.set_isolation_level([offlevel, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
178     
179     @check
180     def commit(self):
181         return self._cnx.commit()
182     
183     @check
184     def rollback(self):
185         return self._cnx.rollback()
186
187     @check
188     def __getattr__(self, name):
189         return getattr(self._obj, name)
190
191 class ConnectionPool(object):
192     def __init__(self, pool, dbname):
193         self.dbname = dbname
194         self._pool = pool
195
196     def cursor(self):
197         return Cursor(self)
198
199     def serialized_cursor(self):
200         return Cursor(self, True)
201
202     def __getattr__(self, name):
203         return getattr(self._pool, name)
204
205 class PoolManager(object):
206     _pools = {}
207     _dsn = None
208     maxconn =  int(tools.config['db_maxconn']) or 64
209     
210     @classmethod 
211     def dsn(cls, db_name):
212         if cls._dsn is None:
213             cls._dsn = ''
214             for p in ('host', 'port', 'user', 'password'):
215                 cfg = tools.config['db_' + p]
216                 if cfg:
217                     cls._dsn += '%s=%s ' % (p, cfg)
218         return '%s dbname=%s' % (cls._dsn, db_name)
219
220     @classmethod
221     def get(cls, db_name):
222         if db_name not in cls._pools:
223             logger = netsvc.Logger()
224             try:
225                 logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Connecting to %s' % (db_name,))
226                 cls._pools[db_name] = ConnectionPool(ThreadedConnectionPool(1, cls.maxconn, cls.dsn(db_name)), db_name)
227             except Exception, e:
228                 logger.notifyChannel('dbpool', netsvc.LOG_ERROR, 'Unable to connect to %s: %s' %
229                                      (db_name, str(e)))
230                 raise
231         return cls._pools[db_name]
232
233     @classmethod
234     def close(cls, db_name):
235         if db_name in cls._pools:
236             logger = netsvc.Logger()
237             logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Closing all connections to %s' % (db_name,))
238             cls._pools[db_name].closeall()
239             del cls._pools[db_name]
240
241 def db_connect(db_name):
242     return PoolManager.get(db_name)
243
244 def close_db(db_name):
245     PoolManager.close(db_name)
246     tools.cache.clean_caches_for_db(db_name)
247     
248
249 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:
250