[FIX]crm : fixed problems regardig email_from/email_to
[odoo/odoo.git] / bin / sql_db.py
index f3b8c86..17bfa34 100644 (file)
@@ -1,29 +1,30 @@
-# -*- encoding: utf-8 -*-
+# -*- coding: utf-8 -*-
 ##############################################################################
-#
-#    OpenERP, Open Source Management Solution  
-#    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>). All Rights Reserved
-#    $Id$
+#    
+#    OpenERP, Open Source Management Solution
+#    Copyright (C) 2004-2009 Tiny SPRL (<http://tiny.be>).
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU Affero General Public License as
+#    published by the Free Software Foundation, either version 3 of the
+#    License, or (at your option) any later version.
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-#    GNU General Public License for more details.
+#    GNU Affero General Public License for more details.
 #
-#    You should have received a copy of the GNU General Public License
-#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#    You should have received a copy of the GNU Affero General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.     
 #
 ##############################################################################
 
-import netsvc
+__all__ = ['db_connect', 'close_db']
+
+import logging
 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE
-from psycopg2.pool import ThreadedConnectionPool
 from psycopg2.psycopg1 import cursor as psycopg1cursor
+from psycopg2.pool import PoolError
 
 import psycopg2.extensions
 
@@ -49,116 +50,135 @@ psycopg2.extensions.register_type(psycopg2.extensions.new_type((700, 701, 1700,)
 
 
 import tools
-import re
+from tools.func import wraps
+from datetime import datetime as mdt
+from datetime import timedelta
+import threading
+from inspect import stack
 
-from mx import DateTime as mdt
+import re
 re_from = re.compile('.* from "?([a-zA-Z_0-9]+)"? .*$');
 re_into = re.compile('.* into "?([a-zA-Z_0-9]+)"? .*$');
 
-def log(msg, lvl=netsvc.LOG_DEBUG):
-    logger = netsvc.Logger()
-    logger.notifyChannel('sql', lvl, msg)
+sql_counter = 0
 
 class Cursor(object):
     IN_MAX = 1000
-    sql_from_log = {}
-    sql_into_log = {}
-    sql_log = False
-    count = 0
-    
-    def check(f):
-        from tools.func import wraps
+    __logger = logging.getLogger('db.cursor')
 
+    def check(f):
         @wraps(f)
         def wrapper(self, *args, **kwargs):
             if self.__closed:
-                raise psycopg2.ProgrammingError('Unable to use the cursor after having closing it')
+                raise psycopg2.ProgrammingError('Unable to use the cursor after having closed it')
             return f(self, *args, **kwargs)
         return wrapper
 
-    def __init__(self, pool, serialized=False):
+    def __init__(self, pool, dbname, serialized=False):
+        self.sql_from_log = {}
+        self.sql_into_log = {}
+        self.sql_log = False
+        self.sql_log_count = 0
+        self.__closed = True    # avoid the call of close() (by __del__) if an exception
+                                # is raised by any of the following initialisations
         self._pool = pool
+        self.dbname = dbname
         self._serialized = serialized
-        self._cnx = pool.getconn()
+        self._cnx = pool.borrow(dsn(dbname))
         self._obj = self._cnx.cursor(cursor_factory=psycopg1cursor)
-        self.__closed = False
+        self.__closed = False   # real initialisation value
         self.autocommit(False)
-        self.dbname = pool.dbname
+        self.__caller = tuple(stack()[2][1:3])
 
-        if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
-            from inspect import stack
-            self.__caller = tuple(stack()[2][1:3])
-        
     def __del__(self):
         if not self.__closed:
-            if tools.config['log_level'] in (netsvc.LOG_DEBUG, netsvc.LOG_DEBUG_RPC):
-                # Oops. 'self' has not been closed explicitly.
-                # The cursor will be deleted by the garbage collector, 
-                # but the database connection is not put back into the connection
-                # pool, preventing some operation on the database like dropping it.
-                # This can also lead to a server overload.
-                msg = "Cursor not closed explicitly\n"  \
-                      "Cursor was created at %s:%s" % self.__caller
-
-                log(msg, netsvc.LOG_WARNING)
+            # Oops. 'self' has not been closed explicitly.
+            # The cursor will be deleted by the garbage collector,
+            # but the database connection is not put back into the connection
+            # pool, preventing some operation on the database like dropping it.
+            # This can also lead to a server overload.
+            msg = "Cursor not closed explicitly\n"  \
+                  "Cursor was created at %s:%s"
+            self.__logger.warn(msg, *self.__caller)
             self.close()
 
     @check
     def execute(self, query, params=None):
-        self.count+=1
         if '%d' in query or '%f' in query:
-            log(query, netsvc.LOG_WARNING)
-            log("SQL queries mustn't containt %d or %f anymore. Use only %s", netsvc.LOG_WARNING)
+            self.__logger.warn(query)
+            self.__logger.warn("SQL queries cannot contain %d or %f anymore. "
+                               "Use only %s")
             if params:
                 query = query.replace('%d', '%s').replace('%f', '%s')
 
         if self.sql_log:
             now = mdt.now()
-        
+
         try:
+            params = params or None
             res = self._obj.execute(query, params)
-        except Exception, e:
-            log("bad query: %s" % self._obj.query)
-            log(e)
+        except psycopg2.ProgrammingError, pe:
+            self.__logger.error("Programming error: %s, in query %s" % (pe, query))
+            raise
+        except Exception:
+            self.__logger.exception("bad query: %s", self._obj.query)
             raise
 
         if self.sql_log:
-            log("query: %s" % self._obj.query)
-            self.count+=1
+            delay = mdt.now() - now
+            delay = delay.seconds * 1E6 + delay.microseconds
+
+            self.__logger.debug("query: %s", self._obj.query)
+            self.sql_log_count+=1
             res_from = re_from.match(query.lower())
             if res_from:
                 self.sql_from_log.setdefault(res_from.group(1), [0, 0])
                 self.sql_from_log[res_from.group(1)][0] += 1
-                self.sql_from_log[res_from.group(1)][1] += mdt.now() - now
+                self.sql_from_log[res_from.group(1)][1] += delay
             res_into = re_into.match(query.lower())
             if res_into:
                 self.sql_into_log.setdefault(res_into.group(1), [0, 0])
                 self.sql_into_log[res_into.group(1)][0] += 1
-                self.sql_into_log[res_into.group(1)][1] += mdt.now() - now
+                self.sql_into_log[res_into.group(1)][1] += delay
         return res
 
     def print_log(self):
+        global sql_counter
+        sql_counter += self.sql_log_count
+        if not self.sql_log:
+            return
         def process(type):
             sqllogs = {'from':self.sql_from_log, 'into':self.sql_into_log}
-            if not sqllogs[type]:
-                return
-            sqllogitems = sqllogs[type].items()
-            sqllogitems.sort(key=lambda k: k[1][1])
             sum = 0
-            log("SQL LOG %s:" % (type,))
-            for r in sqllogitems:
-                log("table: %s: %s/%s" %(r[0], str(r[1][1]), r[1][0]))
-                sum+= r[1][1]
-            log("SUM:%s/%d" % (sum, self.count))
+            if sqllogs[type]:
+                sqllogitems = sqllogs[type].items()
+                sqllogitems.sort(key=lambda k: k[1][1])
+                self.__logger.debug("SQL LOG %s:", type)
+                for r in sqllogitems:
+                    delay = timedelta(microseconds=r[1][1])
+                    self.__logger.debug("table: %s: %s/%s",
+                                        r[0], delay, r[1][0])
+                    sum+= r[1][1]
+                sqllogs[type].clear()
+            sum = timedelta(microseconds=sum)
+            self.__logger.debug("SUM %s:%s/%d [%d]",
+                                type, sum, self.sql_log_count, sql_counter)
             sqllogs[type].clear()
         process('from')
         process('into')
-        self.count = 0
+        self.sql_log_count = 0
         self.sql_log = False
 
     @check
     def close(self):
+        if not self._obj:
+            return
+
         self.print_log()
+
+        if not self._serialized:
+            self.rollback() # Ensure we close the current transaction.
+
         self._obj.close()
 
         # This force the cursor to be freed, and thus, available again. It is
@@ -168,17 +188,17 @@ class Cursor(object):
         # part because browse records keep a reference to the cursor.
         del self._obj
         self.__closed = True
-        self._pool.putconn(self._cnx)
-    
+        self._pool.give_back(self._cnx)
+
     @check
     def autocommit(self, on):
         offlevel = [ISOLATION_LEVEL_READ_COMMITTED, ISOLATION_LEVEL_SERIALIZABLE][bool(self._serialized)]
         self._cnx.set_isolation_level([offlevel, ISOLATION_LEVEL_AUTOCOMMIT][bool(on)])
-    
+
     @check
     def commit(self):
         return self._cnx.commit()
-    
+
     @check
     def rollback(self):
         return self._cnx.rollback()
@@ -187,63 +207,140 @@ class Cursor(object):
     def __getattr__(self, name):
         return getattr(self._obj, name)
 
+
 class ConnectionPool(object):
+
+    __logger = logging.getLogger('db.connection_pool')
+
+    def locked(fun):
+        @wraps(fun)
+        def _locked(self, *args, **kwargs):
+            self._lock.acquire()
+            try:
+                return fun(self, *args, **kwargs)
+            finally:
+                self._lock.release()
+        return _locked
+
+
+    def __init__(self, maxconn=64):
+        self._connections = []
+        self._maxconn = max(maxconn, 1)
+        self._lock = threading.Lock()
+
+    def __repr__(self):
+        used = len([1 for c, u in self._connections[:] if u])
+        count = len(self._connections)
+        return "ConnectionPool(used=%d/count=%d/max=%d)" % (used, count, self._maxconn)
+
+    def _debug(self, msg):
+        self.__logger.debug(repr(self))
+        self.__logger.debug(msg)
+
+    @locked
+    def borrow(self, dsn):
+        self._debug('Borrow connection to %s' % (dsn,))
+
+        result = None
+        for i, (cnx, used) in enumerate(self._connections):
+            if not used and dsn_are_equals(cnx.dsn, dsn):
+                self._debug('Existing connection found at index %d' % i)
+
+                self._connections.pop(i)
+                self._connections.append((cnx, True))
+
+                result = cnx
+                break
+        if result:
+            return result
+
+        if len(self._connections) >= self._maxconn:
+            # try to remove the oldest connection not used
+            for i, (cnx, used) in enumerate(self._connections):
+                if not used:
+                    self._debug('Removing old connection at index %d: %s' % (i, cnx.dsn))
+                    self._connections.pop(i)
+                    break
+            else:
+                # note: this code is called only if the for loop has completed (no break)
+                raise PoolError('The Connection Pool Is Full')
+
+        self._debug('Create new connection')
+        result = psycopg2.connect(dsn=dsn)
+        self._connections.append((result, True))
+        return result
+
+    @locked
+    def give_back(self, connection):
+        self._debug('Give back connection to %s' % (connection.dsn,))
+        for i, (cnx, used) in enumerate(self._connections):
+            if cnx is connection:
+                self._connections.pop(i)
+                self._connections.append((cnx, False))
+                break
+        else:
+            raise PoolError('This connection does not below to the pool')
+
+    @locked
+    def close_all(self, dsn):
+        self._debug('Close all connections to %s' % (dsn,))
+        for i, (cnx, used) in tools.reverse_enumerate(self._connections):
+            if dsn_are_equals(cnx.dsn, dsn):
+                cnx.close()
+                self._connections.pop(i)
+
+
+class Connection(object):
+    __logger = logging.getLogger('db.connection')
+
     def __init__(self, pool, dbname):
         self.dbname = dbname
         self._pool = pool
 
-    def cursor(self):
-        return Cursor(self)
+    def cursor(self, serialized=False):
+        cursor_type = serialized and 'serialized ' or ''
+        self.__logger.debug('create %scursor to "%s"' % (cursor_type, self.dbname,))
+        return Cursor(self._pool, self.dbname, serialized=serialized)
 
     def serialized_cursor(self):
-        return Cursor(self, True)
+        return self.cursor(True)
 
-    def __getattr__(self, name):
-        return getattr(self._pool, name)
-
-class PoolManager(object):
-    _pools = {}
-    _dsn = None
-    maxconn =  int(tools.config['db_maxconn']) or 64
-    
-    @classmethod 
-    def dsn(cls, db_name):
-        if cls._dsn is None:
-            cls._dsn = ''
-            for p in ('host', 'port', 'user', 'password'):
-                cfg = tools.config['db_' + p]
-                if cfg:
-                    cls._dsn += '%s=%s ' % (p, cfg)
-        return '%s dbname=%s' % (cls._dsn, db_name)
-
-    @classmethod
-    def get(cls, db_name):
-        if db_name not in cls._pools:
-            logger = netsvc.Logger()
-            try:
-                logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Connecting to %s' % (db_name,))
-                cls._pools[db_name] = ConnectionPool(ThreadedConnectionPool(1, cls.maxconn, cls.dsn(db_name)), db_name)
-            except Exception, e:
-                logger.notifyChannel('dbpool', netsvc.LOG_ERROR, 'Unable to connect to %s: %s' %
-                                     (db_name, str(e)))
-                raise
-        return cls._pools[db_name]
-
-    @classmethod
-    def close(cls, db_name):
-        if db_name in cls._pools:
-            logger = netsvc.Logger()
-            logger.notifyChannel('dbpool', netsvc.LOG_INFO, 'Closing all connections to %s' % (db_name,))
-            cls._pools[db_name].closeall()
-            del cls._pools[db_name]
+    def __nonzero__(self):
+        """Check if connection is possible"""
+        try:
+            cr = self.cursor()
+            cr.close()
+            return True
+        except:
+            return False
+
+
+_dsn = ''
+for p in ('host', 'port', 'user', 'password'):
+    cfg = tools.config['db_' + p]
+    if cfg:
+        _dsn += '%s=%s ' % (p, cfg)
+
+def dsn(db_name):
+    return '%sdbname=%s' % (_dsn, db_name)
+
+def dsn_are_equals(first, second):
+    def key(dsn):
+        k = dict(x.split('=', 1) for x in dsn.strip().split())
+        k.pop('password', None) # password is not relevant
+        return k
+    return key(first) == key(second)
+
+
+_Pool = ConnectionPool(int(tools.config['db_maxconn']))
 
 def db_connect(db_name):
-    return PoolManager.get(db_name)
+    return Connection(_Pool, db_name)
 
 def close_db(db_name):
-    PoolManager.close(db_name)
+    _Pool.close_all(dsn(db_name))
     tools.cache.clean_caches_for_db(db_name)
-    
+
 
 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: