modify portal.write() to only perform diffs on users' groups
authorRaphael Collet <rco@openerp.com>
Mon, 28 Mar 2011 12:20:29 +0000 (14:20 +0200)
committerRaphael Collet <rco@openerp.com>
Mon, 28 Mar 2011 12:20:29 +0000 (14:20 +0200)
bzr revid: rco@openerp.com-20110328122029-cz5k0vd3t44ggsnk

addons/portal/portal.py

index 9790731..1532eb2 100644 (file)
@@ -65,27 +65,67 @@ class portal(osv.osv):
         return super(portal, self).create(cr, uid, values, context)
     
     def write(self, cr, uid, ids, values, context=None):
-        """ extend write() to assign the portal menu and groups to users """
-        user_object = self.pool.get('res.users')
+        """ extend write() to reflect menu and groups changes on users """
         
-        # first apply changes on the portals themselves
+        # analyse groups changes, and determine how to change users
+        groups_diff = []
+        for change in values.get('group_ids', []):
+            if change[0] in [0, 5, 6]:          # change creates or sets groups,
+                groups_diff = None              # must compute per-portal diff
+                break
+            if change[0] in [3, 4]:             # change add or remove group,
+                groups_diff.append(change)      # add or remove group on users
+        
+        if groups_diff is None:
+            return self._write_compute_diff(cr, uid, ids, values, context)
+        else:
+            return self._write_diff(cr, uid, ids, values, groups_diff, context)
+    
+    def _write_diff(self, cr, uid, ids, values, groups_diff, context=None):
+        """ perform write() and apply groups_diff on users """
+        # first apply portal changes
+        super(portal, self).write(cr, uid, ids, values, context)
+        
+        # then apply menu and group changes on their users
+        user_values = {}
+        if 'menu_id' in values:
+            user_values['menu_id'] = values['menu_id']
+        if groups_diff:
+            user_values['groups_id'] = groups_diff
+        
+        if user_values:
+            user_ids = []
+            for p in self.browse(cr, uid, ids, context):
+                user_ids += get_browse_ids(p.user_ids)
+            self.pool.get('res.users').write(cr, uid, user_ids, user_values, context)
+        
+        return True
+    
+    def _write_compute_diff(self, cr, uid, ids, values, context=None):
+        """ perform write(), then compute and apply groups_diff on each portal """
+        # read group_ids before write() to compute groups_diff
+        old_group_ids = {}
+        for p in self.browse(cr, uid, ids, context):
+            old_group_ids[p.id] = get_browse_ids(p.group_ids)
+        
+        # apply portal changes
         super(portal, self).write(cr, uid, ids, values, context)
         
-        # then reflect changes on the users of each portal
-        #
-        # PERFORMANCE NOTE.  The loop below performs N write() operations, where
-        # N=len(ids).  This may seem inefficient, but in practice it is not,
-        # because: (1) N is pretty small (typically N=1), and (2) it is too
-        # complex (read: bug-prone) to write those updates as a single batch.
-        #
-        plist = self.browse(cr, uid, ids, context)
-        for p in plist:
-            user_ids  = get_browse_ids(p.user_ids)
-            user_values = {
-                'menu_id': get_browse_id(p.menu_id),
-                'groups_id': [(6, 0, get_browse_ids(p.group_ids))],
-            }
-            user_object.write(cr, uid, user_ids, user_values, context)
+        # the changes to apply on users
+        user_values = {}
+        if 'menu_id' in values:
+            user_values['menu_id'] = values['menu_id']
+        
+        # compute groups_diff on each portal, and apply them on users
+        for p in self.browse(cr, uid, ids, context):
+            old_groups = set(old_group_ids[p.id])
+            new_groups = set(get_browse_ids(p.group_ids))
+            # groups_diff: [(3, UNLINKED_ID), ..., (4, LINKED_ID), ...]
+            user_values['groups_id'] = \
+                [(3, g) for g in (old_groups - new_groups)] + \
+                [(4, g) for g in (new_groups - old_groups)]
+            user_ids = get_browse_ids(p.user_ids)
+            self.pool.get('res.users').write(cr, uid, user_ids, user_values, context)
         
         return True
 
@@ -108,7 +148,7 @@ def get_browse_id(obj):
 
 def get_browse_ids(objs):
     """ return the ids of a list of browse() objects """
-    return [(obj and obj.id or default) for obj in objs]
+    return map(get_browse_id, objs)
 
 def copy_random(name):
     """ return "name [N]" for some random integer N """