[MERGE] backport from stable
[odoo/odoo.git] / bin / sql_db.py
index 22e2065..93b031e 100644 (file)
@@ -1,29 +1,30 @@
-# -*- encoding: utf-8 -*-
+# -*- coding: utf-8 -*-
 ##############################################################################
-#
-#    OpenERP, Open Source Management Solution  
-#    Copyright (C) 2004-2008 Tiny SPRL (<http://tiny.be>). All Rights Reserved
-#    $Id$
+#    
+#    OpenERP, Open Source Management Solution
+#    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>).
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU Affero General Public License as
+#    published by the Free Software Foundation, either version 3 of the
+#    License, or (at your option) any later version.
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-#    GNU General Public License for more details.
+#    GNU Affero General Public License for more details.
 #
-#    You should have received a copy of the GNU General Public License
-#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#    You should have received a copy of the GNU Affero General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.     
 #
 ##############################################################################
 
+__all__ = ['db_connect', 'close_db']
+
 import netsvc
-from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_SERIALIZABLE
-from psycopg2.pool import ThreadedConnectionPool
+from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE
 from psycopg2.psycopg1 import cursor as psycopg1cursor
+from psycopg2.pool import PoolError
 
 import psycopg2.extensions
 
@@ -49,62 +50,71 @@ psycopg2.extensions.register_type(psycopg2.extensions.new_type((700, 701, 1700,)
 
 
 import tools
-import re
+from tools.func import wraps
+from datetime import datetime as mdt
+from datetime import timedelta
+import threading
 
-from mx import DateTime as mdt
+import re
 re_from = re.compile('.* from "?([a-zA-Z_0-9]+)"? .*$');
 re_into = re.compile('.* into "?([a-zA-Z_0-9]+)"? .*$');
 
-def log(msg, lvl=netsvc.LOG_DEBUG):
+
+def log(msg, lvl=netsvc.LOG_DEBUG2):
     logger = netsvc.Logger()
     logger.notifyChannel('sql', lvl, msg)
 
+sql_counter = 0
+
 class Cursor(object):
     IN_MAX = 1000
-    sql_from_log = {}
-    sql_into_log = {}
-    sql_log = False
-    count = 0
-    
-    def check(f):
-        from tools.func import wraps
 
+    def check(f):
         @wraps(f)
         def wrapper(self, *args, **kwargs):
-            if not hasattr(self, '_obj'):
-                raise psycopg2.ProgrammingError('Unable to use the cursor after having closing it')
+            if self.__closed:
+                raise psycopg2.ProgrammingError('Unable to use the cursor after having closed it')
             return f(self, *args, **kwargs)
         return wrapper
 
-    def __init__(self, pool):
+    def __init__(self, pool, dbname, serialized=False):
+        self.sql_from_log = {}
+        self.sql_into_log = {}
+        self.sql_log = False
+        self.sql_log_count = 0
+
+        self.__closed = True    # avoid the call of close() (by __del__) if an exception
+                                # is raised by any of the following initialisations
         self._pool = pool
-        self._cnx = pool.getconn()
+        self.dbname = dbname
+        self._serialized = serialized
+        self._cnx = pool.borrow(dsn(dbname))
         self._obj = self._cnx.cursor(cursor_factory=psycopg1cursor)
+        self.__closed = False   # real initialisation value
         self.autocommit(False)
-        self.dbname = pool.dbname
-        
-        from inspect import stack
-        self.__caller = tuple(stack()[2][1:3])
-        
+
+        if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
+            from inspect import stack
+            self.__caller = tuple(stack()[2][1:3])
+
     def __del__(self):
-        if hasattr(self, '_obj'):
+        if not self.__closed:
             # Oops. 'self' has not been closed explicitly.
-            # The cursor will be deleted by the garbage collector, 
+            # The cursor will be deleted by the garbage collector,
             # but the database connection is not put back into the connection
             # pool, preventing some operation on the database like dropping it.
             # This can also lead to a server overload.
-            msg = "Cursor not closed explicitly\n"  \
-                  "Cursor was created at %s:%s" % self.__caller
-
-            log(msg, netsvc.LOG_WARNING)
+            if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
+                msg = "Cursor not closed explicitly\n"  \
+                      "Cursor was created at %s:%s" % self.__caller
+                log(msg, netsvc.LOG_WARNING)
             self.close()
 
     @check
     def execute(self, query, params=None):
-        self.count+=1
         if '%d' in query or '%f' in query:
             log(query, netsvc.LOG_WARNING)
-            log("SQL queries mustn't containt %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
+            log("SQL queries cannot contain %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
             if params:
                 query = query.replace('%d', '%s').replace('%f', '%s')
 
@@ -112,49 +122,69 @@ class Cursor(object):
             now = mdt.now()
         
         try:
+            params = params or None
             res = self._obj.execute(query, params)
-       except psycopg2.ProgrammingError, pe:
-           logger= netsvc.Logger()
-           logger.notifyChannel('sql_db', netsvc.LOG_ERROR, "Programming error: %s, in query %s" % (pe, query))
-           raise
+        except psycopg2.ProgrammingError, pe:
+            logger= netsvc.Logger()
+            logger.notifyChannel('sql_db', netsvc.LOG_ERROR, "Programming error: %s, in query %s" % (pe, query))
+            raise
+        except Exception, e:
+            log("bad query: %s" % self._obj.query)
+            log(e)
+            raise
 
         if self.sql_log:
+            delay = mdt.now() - now
+            delay = delay.seconds * 1E6 + delay.microseconds
+
             log("query: %s" % self._obj.query)
-            self.count+=1
+            self.sql_log_count+=1
             res_from = re_from.match(query.lower())
             if res_from:
                 self.sql_from_log.setdefault(res_from.group(1), [0, 0])
                 self.sql_from_log[res_from.group(1)][0] += 1
-                self.sql_from_log[res_from.group(1)][1] += mdt.now() - now
+                self.sql_from_log[res_from.group(1)][1] += delay
             res_into = re_into.match(query.lower())
             if res_into:
                 self.sql_into_log.setdefault(res_into.group(1), [0, 0])
                 self.sql_into_log[res_into.group(1)][0] += 1
-                self.sql_into_log[res_into.group(1)][1] += mdt.now() - now
+                self.sql_into_log[res_into.group(1)][1] += delay
         return res
 
     def print_log(self):
+        global sql_counter
+        sql_counter += self.sql_log_count
+        if not self.sql_log:
+            return
         def process(type):
             sqllogs = {'from':self.sql_from_log, 'into':self.sql_into_log}
-            if not sqllogs[type]:
-                return
-            sqllogitems = sqllogs[type].items()
-            sqllogitems.sort(key=lambda k: k[1][1])
             sum = 0
-            log("SQL LOG %s:" % (type,))
-            for r in sqllogitems:
-                log("table: %s: %s/%s" %(r[0], str(r[1][1]), r[1][0]))
-                sum+= r[1][1]
-            log("SUM:%s/%d" % (sum, self.count))
-            sqllogs[type].clear()
+            if sqllogs[type]:
+                sqllogitems = sqllogs[type].items()
+                sqllogitems.sort(key=lambda k: k[1][1])
+                log("SQL LOG %s:" % (type,))
+                for r in sqllogitems:
+                    delay = timedelta(microseconds=r[1][1])
+                    log("table: %s: %s/%s" %(r[0], str(delay), r[1][0]))
+                    sum+= r[1][1]
+                sqllogs[type].clear()
+            sum = timedelta(microseconds=sum)
+            log("SUM %s:%s/%d [%d]" % (type, str(sum), self.sql_log_count, sql_counter))
         process('from')
         process('into')
-        self.count = 0
+        self.sql_log_count = 0
         self.sql_log = False
 
     @check
     def close(self):
+        if not self._obj:
+            return
+
         self.print_log()
+
+        if not self._serialized:
+            self.rollback() # Ensure we close the current transaction.
+
         self._obj.close()
 
         # This force the cursor to be freed, and thus, available again. It is
@@ -163,11 +193,13 @@ class Cursor(object):
         # collected as fast as they should). The problem is probably due in
         # part because browse records keep a reference to the cursor.
         del self._obj
-        self._pool.putconn(self._cnx)
-    
+        self.__closed = True
+        self._pool.give_back(self._cnx)
+
     @check
     def autocommit(self, on):
-        self._cnx.set_isolation_level([ISOLATION_LEVEL_SERIALIZABLE, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
+        offlevel = [ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE][bool(self._serialized)]
+        self._cnx.set_isolation_level([offlevel, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
     
     @check
     def commit(self):
@@ -181,60 +213,136 @@ class Cursor(object):
     def __getattr__(self, name):
         return getattr(self._obj, name)
 
+
 class ConnectionPool(object):
-    def __init__(self, pool, dbname):
+
+    def locked(fun):
+        @wraps(fun)
+        def _locked(self, *args, **kwargs):
+            self._lock.acquire()
+            try:
+                return fun(self, *args, **kwargs)
+            finally:
+                self._lock.release()
+        return _locked
+
+
+    def __init__(self, maxconn=64):
+        self._connections = []
+        self._maxconn = max(maxconn, 1)
+        self._lock = threading.Lock()
+        self._logger = netsvc.Logger()
+
+    def _log(self, msg):
+        #self._logger.notifyChannel('ConnectionPool', netsvc.LOG_INFO, msg)
+        pass
+    def _debug(self, msg):
+        #self._logger.notifyChannel('ConnectionPool', netsvc.LOG_DEBUG, msg)
+        pass
+
+    @locked
+    def borrow(self, dsn):
+        self._log('Borrow connection to %s' % (dsn,))
+
+        result = None
+        for i, (cnx, used) in enumerate(self._connections):
+            if not used and dsn_are_equals(cnx.dsn, dsn):
+                self._debug('Existing connection found at index %d' % i)
+
+                self._connections.pop(i)
+                self._connections.append((cnx, True))
+
+                result = cnx
+                break
+        if result:
+            return result
+
+        if len(self._connections) >= self._maxconn:
+            # try to remove the older connection not used
+            for i, (cnx, used) in enumerate(self._connections):
+                if not used:
+                    self._debug('Removing old connection at index %d: %s' % (i, cnx.dsn))
+                    self._connections.pop(i)
+                    break
+            else:
+                # note: this code is called only if the for loop has completed (no break)
+                raise PoolError('Connection Pool Full')
+
+        self._debug('Create new connection')
+        result = psycopg2.connect(dsn=dsn)
+        self._connections.append((result, True))
+        return result
+
+    @locked
+    def give_back(self, connection):
+        self._log('Give back connection to %s' % (connection.dsn,))
+        for i, (cnx, used) in enumerate(self._connections):
+            if cnx is connection:
+                self._connections.pop(i)
+                self._connections.append((cnx, False))
+                break
+        else:
+            raise PoolError('This connection does not below to the pool')
+
+    @locked
+    def close_all(self, dsn):
+        for i, (cnx, used) in tools.reverse_enumerate(self._connections):
+            if dsn_are_equals(cnx.dsn, dsn):
+                cnx.close()
+                self._connections.pop(i)
+
+
+class Connection(object):
+    __LOCKS = {}
+
+    def __init__(self, pool, dbname, unique=False):
         self.dbname = dbname
         self._pool = pool
+        self._unique = unique
+        if unique:
+            if dbname not in self.__LOCKS:
+                self.__LOCKS[dbname] = threading.Lock()
+            self.__LOCKS[dbname].acquire()
 
-    def cursor(self):
-        return Cursor(self)
+    def __del__(self):
+        if self._unique:
+            close_db(self.dbname)
+            self.__LOCKS[self.dbname].release()
 
-    def __getattr__(self, name):
-        return getattr(self._pool, name)
+    def cursor(self, serialized=False):
+        return Cursor(self._pool, self.dbname, serialized=serialized)
 
-class PoolManager(object):
-    _pools = {}
-    _dsn = None
-    maxconn =  int(tools.config['db_maxconn']) or 64
-    
-    @classmethod 
-    def dsn(cls, db_name):
-        if cls._dsn is None:
-            cls._dsn = ''
-            for p in ('host', 'port', 'user', 'password'):
-                cfg = tools.config['db_' + p]
-                if cfg:
-                    cls._dsn += '%s=%s ' % (p, cfg)
-        return '%s dbname=%s' % (cls._dsn, db_name)
-
-    @classmethod
-    def get(cls, db_name):
-        if db_name not in cls._pools:
-            logger = netsvc.Logger()
-            try:
-                logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Connecting to %s' % (db_name,))
-                cls._pools[db_name] = ConnectionPool(ThreadedConnectionPool(1, cls.maxconn, cls.dsn(db_name)), db_name)
-            except Exception, e:
-                logger.notifyChannel('dbpool', netsvc.LOG_ERROR, 'Unable to connect to %s: %s' %
-                                     (db_name, str(e)))
-                raise
-        return cls._pools[db_name]
-
-    @classmethod
-    def close(cls, db_name):
-        if db_name in cls._pools:
-            logger = netsvc.Logger()
-            logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Closing all connections to %s' % (db_name,))
-            cls._pools[db_name].closeall()
-            del cls._pools[db_name]
-
-def db_connect(db_name, serialize=0):
-    return PoolManager.get(db_name)
+    def serialized_cursor(self):
+        return self.cursor(True)
+
+
+_dsn = ''
+for p in ('host', 'port', 'user', 'password'):
+    cfg = tools.config['db_' + p]
+    if cfg:
+        _dsn += '%s=%s ' % (p, cfg)
+
+def dsn(db_name):
+    return '%sdbname=%s' % (_dsn, db_name)
+
+def dsn_are_equals(first, second):
+    def key(dsn):
+        k = dict(x.split('=', 1) for x in dsn.strip().split())
+        k.pop('password', None) # password is not relevant
+        return k
+    return key(first) == key(second)
+
+
+_Pool = ConnectionPool(int(tools.config['db_maxconn']))
+
+def db_connect(db_name):
+    unique = db_name in ['template1', 'template0']
+    return Connection(_Pool, db_name, unique)
 
 def close_db(db_name):
-    PoolManager.close(db_name)
-    tools.cache.clean_cache_for_db(db_name)
-    
+    _Pool.close_all(dsn(db_name))
+    tools.cache.clean_caches_for_db(db_name)
+
 
 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: