KERNEL: fix pooler netsvc for multi-db
[odoo/odoo.git] / bin / osv / osv.py
1 ##############################################################################
2 #
3 # Copyright (c) 2004-2006 TINY SPRL. (http://tiny.be) All Rights Reserved.
4 #                    Fabien Pinckaers <fp@tiny.Be>
5 #
6 # WARNING: This program as such is intended to be used by professional
7 # programmers who take the whole responsability of assessing all potential
8 # consequences resulting from its eventual inadequacies and bugs
9 # End users who are looking for a ready-to-use solution with commercial
10 # garantees and support are strongly adviced to contract a Free Software
11 # Service Company
12 #
13 # This program is Free Software; you can redistribute it and/or
14 # modify it under the terms of the GNU General Public License
15 # as published by the Free Software Foundation; either version 2
16 # of the License, or (at your option) any later version.
17 #
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
21 # GNU General Public License for more details.
22 #
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
26 #
27 ##############################################################################
28
29 #
30 # OSV: Objects Services
31 #
32
33 import orm
34 import netsvc
35 import pooler
36 import copy
37
38 import psycopg
39
40 module_list = []
41 module_class_list = {}
42 class_pool = {}
43
class except_osv(Exception):
    """Error raised by model code and reported back through the RPC layer.

    :param name: error name, stored as ``self.name``
    :param value: error detail, stored as ``self.value``
    :param exc_type: category string stored as ``self.exc_type``
                     (defaults to 'warning')

    ``self.args`` is ``(exc_type, name)``.
    """

    def __init__(self, name, value, exc_type='warning'):
        # Exception.__init__ sets self.args to exactly the given values.
        Exception.__init__(self, exc_type, name)
        self.name, self.value, self.exc_type = name, value, exc_type
50
51 class osv_pool(netsvc.Service):
52
53         def __init__(self):
54                 self.obj_pool = {}
55                 self.module_object_list = {}
56                 self.created = []
57                 self._sql_error = {}
58                 netsvc.Service.__init__(self, 'object_proxy', audience='')
59                 self.joinGroup('web-services')
60                 self.exportMethod(self.exportedMethods)
61                 self.exportMethod(self.obj_list)
62                 self.exportMethod(self.exec_workflow)
63                 self.exportMethod(self.execute)
64                 self.exportMethod(self.execute_cr)
65
66         def execute_cr(self, cr, uid, obj, method, *args, **kw):
67                 #
68                 # TODO: check security level
69                 #
70                 try:
71                         if (not method in getattr(self.obj_pool[obj],'_protected')) and len(args) and args[0] and len(self.obj_pool[obj]._inherits):
72                                 types = {obj: args[0]}
73                                 cr.execute('select inst_type,inst_id,obj_id from inherit where obj_type=%s and  obj_id in ('+','.join(map(str,args[0]))+')', (obj,))
74                                 for ty,id,id2 in cr.fetchall():
75                                         if not ty in types:
76                                                 types[ty]=[]
77                                         types[ty].append(id)
78                                         types[obj].remove(id2)
79                                 for t,ids in types.items():
80                                         if len(ids):
81                                                 t = self.obj_pool[t]
82                                                 res = getattr(t,method)(cr, uid, ids, *args[1:], **kw)
83                         else:
84                                 obj = self.obj_pool[obj]
85                                 res = getattr(obj,method)(cr, uid, *args, **kw)
86                         return res
87                 except orm.except_orm, inst:
88                         self.abortResponse(1, inst.value[0], inst.name, inst.value[1])
89                 except except_osv, inst:
90                         self.abortResponse(1, inst.name, inst.exc_type, inst.value)
91                 except psycopg.IntegrityError, inst:
92                         for key in self._sql_error.keys():
93                                 if key in inst[0]:
94                                         self.abortResponse(1, 'Constraint Error', 'warning', self._sql_error[key])
95                         self.abortResponse(1, 'Integrity Error', 'warning', inst[0])
96
97
98         def execute(self, db, uid, obj, method, *args, **kw):
99                 db, pool = pooler.get_db_and_pool(db)
100                 cr = db.cursor()
101                 try:
102                         try:
103                                 res = pool.execute_cr(cr, uid, obj, method, *args, **kw)
104                                 cr.commit()
105                         except Exception:
106                                 cr.rollback()
107                                 raise
108                 finally:
109                         cr.close()
110                 return res
111
112         def exec_workflow_cr(self, cr, uid, obj, method, *args):
113                 wf_service = netsvc.LocalService("workflow")
114                 wf_service.trg_validate(uid, obj, args[0], method, cr)
115                 return True
116
117         def exec_workflow(self, db, uid, obj, method, *args):
118                 cr = pooler.get_db(db).cursor()
119                 try:
120                         try:
121                                 res = self.exec_workflow_cr(cr, uid, obj, method, *args)
122                                 cr.commit()
123                         except Exception:
124                                 cr.rollback()
125                                 raise
126                 finally:
127                         cr.close()
128                 return res
129
130         def obj_list(self):
131                 return self.obj_pool.keys()
132
133         # adds a new object instance to the object pool. 
134         # if it already existed, the instance is replaced
135         def add(self, name, obj_inst):
136                 if self.obj_pool.has_key(name):
137                         del self.obj_pool[name]
138                 self.obj_pool[name] = obj_inst
139
140                 module = str(obj_inst.__class__)[6:]
141                 module = module[:len(module)-1]
142                 module = module.split('.')[0][2:]
143                 self.module_object_list.setdefault(module, []).append(obj_inst)
144
145         def get(self, name):
146                 obj = self.obj_pool.get(name, None)
147 # We cannot uncomment this line because it breaks initialisation since objects do not initialize
148 # in the correct order and the ORM doesnt support correctly when some objets do not exist yet
149 #               assert obj, "object %s does not exist !" % name
150                 return obj
151
152         #TODO: pass a list of modules to load
153         def instanciate(self, module):
154 #               print "module list:", module_list
155 #               for module in module_list:
156                 res = []
157                 class_list = module_class_list.get(module, [])
158 #                       if module not in self.module_object_list:
159 #               print "%s class_list:" % module, class_list
160                 for klass in class_list:
161                         res.append(klass.createInstance(self, module))
162                 return res
163 #                       else:
164 #                               print "skipping module", module
165
166 #pooler.get_pool(cr.dbname) = osv_pool()
167
#
# Metaclass applying `_inherit` at class-creation time.
# NOTE(review): appears unused -- the `__metaclass__` line in `osv` is
# commented out.
# TODO: see if we can use the pool var instead of the class_pool one.
#
class inheritor(type):
    """Build classes that inherit declaratively through ``_inherit``.

    When the class body defines ``_inherit``, the class of that name is
    taken from the module-level ``class_pool``. Its ``_columns``,
    ``_defaults`` and ``_inherits`` are shallow-copied and updated with the
    new class's own values, and the parent becomes the only base.
    Otherwise the class is created unchanged.
    """

    def __new__(cls, name, bases, d):
        parent_name = d.get('_inherit', None)
        if not parent_name:
            return type.__new__(cls, name, bases, d)

        parent_class = class_pool.get(parent_name)
        assert parent_class, "parent class %s does not exist !" % parent_name
        for attr in ('_columns', '_defaults', '_inherits'):
            merged = copy.copy(getattr(parent_class, attr))
            merged.update(d.get(attr, {}))
            d[attr] = merged
        # TODO: update _inherits of other objects
        return type.__new__(cls, name, (parent_class,), d)
187
188
189
class osv(orm.orm):
    """Base class for model definitions.

    Calling a subclass (``my_model()``) does not build an object. __new__
    only registers the class in module_class_list, class_pool and
    module_list, then returns None. Real instances are built per pool by
    createInstance().
    """
    #__metaclass__ = inheritor

    def __new__(cls):
        if not hasattr(cls, '_module'):
            # First dotted component of the class repr, e.g.
            # "<class 'pkg.mod.Name'>" -> "pkg".
            module = str(cls)[6:]
            module = module[:len(module)-1]
            module = module.split('.')[0][2:]
            cls._module = module
        module_class_list.setdefault(cls._module, []).append(cls)
        class_pool[cls._name] = cls
        # Test cls._module, not the local `module`, which is unbound when
        # the class already declares _module (previously a NameError).
        if cls._module not in module_list:
            module_list.append(cls._module)
        return None

    def createInstance(cls, pool, module):
        """Create and return an instance of ``cls`` bound to ``pool``.

        If ``cls`` declares ``_inherit``, a new class is created that
        derives from ``cls`` and from the class of the parent instance
        already in ``pool``. Its ``_columns``, ``_defaults`` and
        ``_inherits`` are the parent's values updated with those defined
        directly on ``cls``. ``module`` is not used.
        """
        parent_name = hasattr(cls, '_inherit') and cls._inherit
        if parent_name:
            parent_obj = pool.get(parent_name)
            # Check the instance, not its class: pool.get() returns None
            # for an unknown model, and None.__class__ is always truthy.
            assert parent_obj is not None, "parent class %s does not exist !" % parent_name
            parent_class = parent_obj.__class__
            ndict = {}
            for s in ('_columns', '_defaults', '_inherits'):
                new_dict = copy.copy(getattr(parent_obj, s))
                new_dict.update(cls.__dict__.get(s, {}))
                ndict[s] = new_dict
            cls = type(str(cls), (cls, parent_class), ndict)

        # Bypass __new__, which only registers the class and returns None.
        obj = object.__new__(cls)
        obj.__init__(pool)
        return obj
    createInstance = classmethod(createInstance)

    def __init__(self, pool):
        """Register this instance in ``pool`` under ``self._name`` and run
        the ORM initialisation."""
        pool.add(self._name, self)
        self.pool = pool
        orm.orm.__init__(self)
241
class Cacheable(object):
    """Mixin providing a nested-dict cache addressed by key sequences.

    A key such as ('model', 42, ctx) is stored as _cache['model'][42][ctx].

    NOTE(review): ``_cache`` is a class attribute, so every instance (and
    subclass) shares one dict, and clear() empties it for all of them.
    """

    _cache = {}
    count = 0

    def __delete_key(self, key):
        # Walk to the parent dict of the last key component and delete it.
        odico = self._cache
        for key_item in key[:-1]:
            odico = odico[key_item]
        del odico[key[-1]]

    def __add_key(self, key, value):
        # Walk (creating as needed) to the parent dict of the last component.
        odico = self._cache
        for key_item in key[:-1]:
            odico = odico.setdefault(key_item, {})
        odico[key[-1]] = value

    def add(self, key, value):
        """Store ``value`` under the key path ``key``."""
        self.__add_key(key, value)

    def invalidate(self, key):
        """Remove the entry at key path ``key``; raises KeyError if absent."""
        self.__delete_key(key)

    def get(self, key):
        """Return the value stored under ``key``, or None if absent.

        ``key`` is walked the same way as in add() and invalidate(), so
        get(k) finds what add(k, v) stored. (The former flat
        ``_cache[key]`` lookup never did.) A key that cannot be sliced or
        is empty is looked up directly at the top level, as before.
        """
        try:
            path, last = key[:-1], key[-1]
        except (TypeError, IndexError):
            path, last = (), key
        node = self._cache
        try:
            for part in path:
                node = node[part]
            return node[last]
        except KeyError:
            return None

    def clear(self):
        """Empty the (shared) cache."""
        self._cache.clear()
        self._items = []
275
def filter_dict(d, fields):
    """Return a new dict with the entries of ``d`` whose key is in
    ``fields`` or is 'id'. Keys missing from ``d`` are skipped.

    ``fields`` may be any iterable of keys (list, tuple, dict keys view...).
    Previously only lists worked, because of ``fields + ['id']``.
    """
    res = {}
    for f in list(fields) + ['id']:
        if f in d:
            res[f] = d[f]
    return res
282
class cacheable_osv(osv, Cacheable):
    """osv model whose read() caches the classic (stored) field values.

    Rows are cached under the key path (model name, record id, ctx), where
    ctx is the tuple of context values named in ``_relevant``.
    """

    _relevant = ['lang']

    def __init__(self, pool):
        # osv.__init__ requires the pool, and createInstance() calls
        # obj.__init__(pool). The former zero-argument signature always
        # raised TypeError.
        super(cacheable_osv, self).__init__(pool)

    def read(self, cr, user, ids, fields=None, context=None, load='_classic_read'):
        """Read ``fields`` of records ``ids``, using cached rows when present.

        Classic fields of uncached records are fetched with the parent
        read() and cached. Function fields are recomputed on every call
        and never cached. Returns a list of dicts, each including 'id'.
        """
        if context is None:
            context = {}
        fields = list(fields or self._columns.keys())
        # A tuple, not a list: it becomes part of a dict key.
        ctx = tuple([context.get(x, False) for x in self._relevant])

        result, tofetch = [], []
        for rec_id in ids:
            cached = self.get((self._name, rec_id, ctx))
            if cached:
                result.append(filter_dict(cached, fields))
            else:
                tofetch.append(rec_id)

        # "local" (not inherited) fields which are classic or many2one
        nfields = [name for name, column in self._columns.items()
                   if column._classic_write]
        # inherited fields which are classic or many2one (the object at
        # index 2 of each _inherit_fields value carries _classic_write)
        nfields += [name for name, info in self._inherit_fields.items()
                    if info[2]._classic_write]

        if tofetch:
            rows = super(cacheable_osv, self).read(cr, user, tofetch, nfields, context, load)
            for row in rows:
                self.add((self._name, row['id'], ctx), row)
                result.append(filter_dict(row, fields))

        # Compute function fields if needed.
        stored = set(nfields)
        for f in fields:
            if f not in stored:
                fvals = self._columns[f].get(cr, self, ids, f, user, context=context)
                for row in result:
                    row[f] = fvals[row['id']]

        # TODO: sort by self._order !!
        return result

    def invalidate(self, key):
        """Drop every cached row of record key[1] of model key[0].

        Missing entries are ignored, so writing a record that was never
        read no longer raises KeyError.
        """
        self._cache.get(key[0], {}).pop(key[1], None)

    def write(self, cr, user, ids, values, context=None):
        """Invalidate the cached rows of ``ids``, then write through."""
        if context is None:
            context = {}
        for rec_id in ids:
            self.invalidate((self._name, rec_id))
        return super(cacheable_osv, self).write(cr, user, ids, values, context)

    def unlink(self, cr, user, ids):
        """Clear the whole (shared) cache, then delete ``ids``."""
        self.clear()
        return super(cacheable_osv, self).unlink(cr, user, ids)
338
339 #cacheable_osv = osv
340
341 # vim:noexpandtab:
342
343 #class FakePool(object):
344 #       def __init__(self, module):
345 #               self.preferred_module = module
346         
347 #       def get(self, name):
348 #               localpool = module_objects_dict.get(self.preferred_module, {'dict': {}})['dict']
349 #               if name in localpool:
350 #                       obj = localpool[name]
351 #               else:
352 #                       obj = pooler.get_pool(cr.dbname).get(name)
353 #               return obj
354
355 #               fake_pool = self
356 #               class fake_class(obj.__class__):
357 #                       def __init__(self):
358 #                               super(fake_class, self).__init__()
359 #                               self.pool = fake_pool
360                                 
361 #               return fake_class()
362