[CLEAN] Query: cleaned a bit the code. All joins now goes through the same method...
authorThibault Delavallée <tde@openerp.com>
Fri, 7 Dec 2012 15:42:04 +0000 (16:42 +0100)
committerThibault Delavallée <tde@openerp.com>
Fri, 7 Dec 2012 15:42:04 +0000 (16:42 +0100)
bzr revid: tde@openerp.com-20121207154204-mx036lpj3vdclu77

openerp/osv/expression.py
openerp/osv/orm.py
openerp/osv/query.py
openerp/tests/test_osv.py

index 703b750..8535213 100644 (file)
@@ -346,6 +346,20 @@ def generate_table_alias(src_table_alias, joined_tables=[]):
     return ('%s' % alias, '%s as %s' % (_quote(joined_tables[-1][0]), _quote(alias)))
 
 
+def get_alias_from_query(from_query):
+    """ :param string from_query: is something like :
+        - '"res_partner"' OR
+        - '"res_partner" as "res_users__partner_id"''
+        :param tuple result: (unquoted table name, unquoted alias)
+            i.e. (res_partners, res_partner) OR (res_partner, res_users__partner_id)
+    """
+    from_splitted = from_query.split(' as ')
+    if len(from_splitted) > 1:
+        return (from_splitted[0].replace('"', ''), from_splitted[1].replace('"', ''))
+    else:
+        return (from_splitted[0].replace('"', ''), from_splitted[0].replace('"', ''))
+
+
 def normalize_leaf(element):
     """ Change a term's operator to some canonical form, simplifying later
         processing. """
index d2c0a0f..30746bd 100644 (file)
@@ -2728,14 +2728,8 @@ class BaseModel(object):
         """
         inherits_field = current_model._inherits[parent_model_name]
         parent_model = self.pool.get(parent_model_name)
-        parent_alias = query.add_implicit_join((current_model, parent_model, inherits_field, 'id', inherits_field))
+        parent_alias, parent_alias_statement = query.add_join((current_model._table, parent_model._table, inherits_field, 'id', inherits_field), implicit=True)
         return parent_alias
-        # table_alias = expression.generate_table_alias(current_model, [(parent_model, inherits_field)])
-        # print '\t... _inherits_join_add trying to add %s in %s' % (table_alias, query.tables)
-        # # query.add_table(table_alias)
-        # if table_alias not in query.tables:
-        #     query.tables.append(table_alias)
-        #     query.where_clause.append('("%s".%s = %s.id)' % (current_model._table, inherits_field, table_alias))
 
     def _inherits_join_calc(self, field, query):
         """
@@ -2746,7 +2740,6 @@ class BaseModel(object):
         :param query: query object on which the JOIN should be added
         :return: qualified name of field, to be used in SELECT clause
         """
-        print '\t--> _inherits_join_calc'
         current_table = self
         parent_alias = current_table._table
         while field in current_table._inherit_fields and not field in current_table._columns:
@@ -4667,15 +4660,12 @@ class BaseModel(object):
                 :param model child_object: model object, base of the rule application
             """
             if added_clause:
-                print '--> _apply_ir_rules.apply_rule,', added_clause, added_params, added_tables, parent_model, child_object
                 if parent_model and child_object:
-                    print '\t... calling inherits-Join_add with parent_model %s, child_object %s' % (parent_model, child_object)
                     # as inherited rules are being applied, we need to add the missing JOIN
                     # to reach the parent table (if it was not JOINed yet in the query)
                     parent_alias = child_object._inherits_join_add(child_object, parent_model, query)
                     # inherited rules are applied on the external table -> need to get the alias and replace
                     parent_table = self.pool.get(parent_model)._table
-                    print parent_table, parent_alias
                     added_clause = [clause.replace('"%s"' % parent_table, '"%s"' % parent_alias) for clause in added_clause]
                     # not sure of myself here (in the ORM, this statment is quite cool)
                     new_tables = []
@@ -4686,11 +4676,9 @@ class BaseModel(object):
                             new_table = table.replace('"%s"' % parent_table, '"%s"' % parent_alias)
                         new_tables.append(new_table)
                     added_tables = new_tables
-                print '\t... adding tables %s, clause %s (params %s)' % (added_tables, added_clause, added_params)
                 query.where_clause += added_clause
                 query.where_clause_params += added_params
                 for table in added_tables:
-                    print '\t... adding table %s in %s' % (table, query.tables)
                     if table not in query.tables:
                         query.tables.append(table)
                 return True
@@ -4715,7 +4703,6 @@ class BaseModel(object):
 
         :return: the qualified field name to use in an ORDER BY clause to sort by ``order_field``
         """
-        print '_generate_m2o_order_by'
         if order_field not in self._columns and order_field in self._inherit_fields:
             # also add missing joins for reaching the table containing the m2o field
             qualified_field = self._inherits_join_calc(order_field, query)
@@ -4747,7 +4734,7 @@ class BaseModel(object):
         # Join the dest m2o table if it's not joined yet. We use [LEFT] OUTER join here
         # as we don't want to exclude results that have NULL values for the m2o
         src_table, src_field = qualified_field.replace('"', '').split('.', 1)
-        dst_alias = query.add_join((src_table, dest_model._table, src_field, 'id'), outer=True)
+        dst_alias, dst_alias_statement = query.add_join((src_table, dest_model._table, src_field, 'id', src_field), implicit=False, outer=True)
         qualify = lambda field: '"%s"."%s"' % (dst_alias, field)
         return map(qualify, m2o_order) if isinstance(m2o_order, list) else qualify(m2o_order)
 
@@ -4769,7 +4756,6 @@ class BaseModel(object):
             return order_list
         order_by_clause = ','.join(_split_order(self._order, self._table))
         if order_spec:
-            print '-->_generate_order_by beginning'
             order_by_elements = []
             self._check_qorder(order_spec)
             for order_part in order_spec.split(','):
@@ -4804,7 +4790,6 @@ class BaseModel(object):
                         order_by_elements.append("%s %s" % (inner_clause, order_direction))
             if order_by_elements:
                 order_by_clause = ",".join(order_by_elements)
-            print '-->_generate_order_by ending'
 
         return order_by_clause and (' ORDER BY %s ' % order_by_clause) or ''
 
index a07def3..7f1bd93 100644 (file)
@@ -53,9 +53,6 @@ class Query(object):
         # holds the list of tables joined using default JOIN.
         # the table names are stored double-quoted (backwards compatibility)
         self.tables = tables or []
-        # holds a mapping of table aliases:
-        #   self._table_alias_mapping = {'alias_1': 'table_name'}
-        self._table_alias_mapping = {}
 
         # holds the list of WHERE clause elements, to be joined with
         # 'AND' when generating the final query
@@ -79,29 +76,19 @@ class Query(object):
         #                                 LEFT JOIN "table_c" ON ("table_a"."table_a_col2" = "table_c"."table_c_col")
         self.joins = joins or {}
 
-    def _add_table_alias(self, table_alias):
-        pass
-
     def _get_table_aliases(self):
-        aliases = []
-        for table in self.tables:
-            if len(table.split(' as ')) > 1:
-                aliases.append(table.split(' as ')[1].replace('"', ''))
-            else:
-                aliases.append(table.replace('"', ''))
-        # print '--', aliases
-        return aliases
+        from openerp.osv.expression import get_alias_from_query
+        return [get_alias_from_query(from_statement)[1] for from_statement in self.tables]
 
     def _get_alias_mapping(self):
+        from openerp.osv.expression import get_alias_from_query
         mapping = {}
-        aliases = self._get_table_aliases()
-        for alias in aliases:
-            for table in self.tables:
-                if '"%s"' % (alias) in table:
-                    mapping.setdefault(alias, table)
+        for table in self.tables:
+            alias, statement = get_alias_from_query(table)
+            mapping[statement] = table
         return mapping
 
-    def add_new_join(self, connection, implicit=True, outer=False):
+    def add_join(self, connection, implicit=True, outer=False):
         """ Join a destination table to the current table.
 
             :param implicit: False if the join is an explicit join. This allows
@@ -127,115 +114,41 @@ class Query(object):
         """
         from openerp.osv.expression import generate_table_alias
         (lhs, table, lhs_col, col, link) = connection
-        alias, alias_statement = generate_table_alias(lhs._table, [(table._table, link)])
+        alias, alias_statement = generate_table_alias(lhs, [(table, link)])
 
         if implicit:
-            print '\t\t... Query: trying to add %s in %s (received %s)' % (alias_statement, self.tables, connection)
             if alias_statement not in self.tables:
                 self.tables.append(alias_statement)
-                condition = '("%s"."%s" = "%s"."%s")' % (lhs._table, lhs_col, alias, col)
-                print '\t\t... added %s' % (condition)
+                condition = '("%s"."%s" = "%s"."%s")' % (lhs, lhs_col, alias, col)
+                # print '\t\t... Query: added %s in %s (received %s)' % (alias_statement, self.tables, connection)
                 self.where_clause.append(condition)
-            return alias
+            else:
+                # print '\t\t... Query: not added %s in %s (received %s)' % (alias_statement, self.tables, connection)
+                # already joined
+                pass
+            return alias, alias_statement
         else:
-            (lhs, table, lhs_col, col) = connection
-            lhs = _quote(lhs)
-            table = _quote(table)
-            print connection
-            aliases = []
-            for table in self.tables:
-                if len(table.split(' as ')) > 1:
-                    aliases.append(table.split(' as ')[1])
-                else:
-                    aliases.append(table)
-            print '--', aliases
-            aliases = [table.split(' as ') for table in self.tables]
-            assert lhs in self.aliases, "Left-hand-side table %s must already be part of the query tables %s!" % (lhs, str(self.tables))
-            if table in self.tables:
+            aliases = self._get_table_aliases()
+            assert lhs in aliases, "Left-hand-side table %s must already be part of the query tables %s!" % (lhs, str(self.tables))
+            if alias_statement in self.tables:
                 # already joined, must ignore (promotion to outer and multiple joins not supported yet)
+                # print '\t\t... Query: not added %s in %s (received %s)' % (alias_statement, self.tables, connection)
                 pass
             else:
                 # add JOIN
-                self.tables.append(table)
-                self.joins.setdefault(lhs, []).append((table, lhs_col, col, outer and 'LEFT JOIN' or 'JOIN'))
-            return self
-
-    def add_implicit_join(self, connection):
-        """ Adds an implicit join. This means that left-hand table is added to the
-            Query.tables (adding a table in the from clause), and that a join
-            condition is added in Query.where_clause.
-
-            Implicit joins use expression.generate_table_alias to generate the
-            alias the the joined table.
-
-            :param connection: a tuple``(lhs, table, lhs_col, col, link)`` Please
-                refer to expression.py for more details about joins.
-        """
-        from openerp.osv.expression import generate_table_alias
-        (lhs, table, lhs_col, col, link) = connection
-        alias, alias_statement = generate_table_alias(lhs._table, [(table._table, link)])
-        print '\t\t... Query: trying to add %s in %s (received %s)' % (alias_statement, self.tables, connection)
-        if alias_statement not in self.tables:
-            self.tables.append(alias_statement)
-            condition = '("%s"."%s" = "%s"."%s")' % (lhs._table, lhs_col, alias, col)
-            print '\t\t... added %s' % (condition)
-            self.where_clause.append(condition)
-        return alias
-
-    def add_join(self, connection, outer=False):
-        """Adds the JOIN specified in ``connection``.
-
-        :param connection: a tuple ``(lhs, table, lhs_col, col)``.
-                           The join corresponds to the SQL equivalent of::
-
-                               (lhs.lhs_col = table.col)
-
-                           Note that all connection elements are strings.
-
-        :param outer: True if a LEFT OUTER JOIN should be used, if possible
-                      (no promotion to OUTER JOIN is supported in case the JOIN
-                      was already present in the query, as for the moment
-                      implicit INNER JOINs are only connected from NON-NULL
-                      columns so it would not be correct (e.g. for
-                      ``_inherits`` or when a domain criterion explicitly
-                      adds filtering)
-        """
-        from openerp.osv.expression import generate_table_alias
-        (lhs, table, lhs_col, col) = connection
-        # lhs = _quote(lhs)
-        # table = _quote(table)
-        print '\t\t... Query.add_join(): adding connection %s' % str(connection)
-
-        aliases = self._get_table_aliases()
-
-        assert lhs in aliases, "Left-hand-side table %s must already be part of the query tables %s!" % (lhs, str(self.tables))
-
-        rhs, rhs_statement = generate_table_alias(lhs, [(connection[1], connection[2])])
-        print rhs, rhs_statement
-
-        if rhs_statement in self.tables:
-            # already joined, must ignore (promotion to outer and multiple joins not supported yet)
-            pass
-        else:
-            # add JOIN
-            self.tables.append(rhs_statement)
-            self.joins.setdefault(lhs, []).append((rhs, lhs_col, col, outer and 'LEFT JOIN' or 'JOIN'))
-        return rhs
+                self.tables.append(alias_statement)
+                self.joins.setdefault(lhs, []).append((alias, lhs_col, col, outer and 'LEFT JOIN' or 'JOIN'))
+            return alias, alias_statement
 
     def get_sql(self):
-        """Returns (query_from, query_where, query_params)"""
+        """ Returns (query_from, query_where, query_params). """
+        from openerp.osv.expression import get_alias_from_query
         query_from = ''
         tables_to_process = list(self.tables)
-
         alias_mapping = self._get_alias_mapping()
 
-        # print 'tables_to_process %s' % (tables_to_process)
-        # print 'self.joins %s' % (self.joins)
-        # print 'alias_mapping %s' % (alias_mapping)
-
         def add_joins_for_table(table, query_from):
             for (dest_table, lhs_col, col, join) in self.joins.get(table, []):
-                # print dest_table
                 tables_to_process.remove(alias_mapping[dest_table])
                 query_from += ' %s %s ON ("%s"."%s" = "%s"."%s")' % \
                     (join, alias_mapping[dest_table], table, lhs_col, dest_table, col)
@@ -244,8 +157,9 @@ class Query(object):
 
         for table in tables_to_process:
             query_from += table
-            if _get_alias_from_statement(table) in self.joins:
-                query_from = add_joins_for_table(_get_alias_from_statement(table), query_from)
+            table_alias = get_alias_from_query(table)[1]
+            if table_alias in self.joins:
+                query_from = add_joins_for_table(table_alias, query_from)
             query_from += ','
         query_from = query_from[:-1]  # drop last comma
         return (query_from, " AND ".join(self.where_clause), self.where_clause_params)
index 80be816..e36495f 100644 (file)
@@ -29,8 +29,8 @@ class QueryTestCase(unittest.TestCase):
         query = Query()
         query.tables.extend(['"product_product"', '"product_template"'])
         query.where_clause.append("product_product.template_id = product_template.id")
-        query.add_join(("product_template", "product_category", "categ_id", "id"), outer=False)  # add normal join
-        query.add_join(("product_product", "res_user", "user_id", "id"), outer=True)  # outer join
+        query.add_join(("product_template", "product_category", "categ_id", "id", "categ_id"), implicit=False, outer=False)  # add normal join
+        query.add_join(("product_product", "res_user", "user_id", "id", "user_id"), implicit=False, outer=True)  # outer join
         self.assertEquals(query.get_sql()[0].strip(),
             """"product_product" LEFT JOIN "res_user" as "product_product__user_id" ON ("product_product"."user_id" = "product_product__user_id"."id"),"product_template" JOIN "product_category" as "product_template__categ_id" ON ("product_template"."categ_id" = "product_template__categ_id"."id") """.strip())
         self.assertEquals(query.get_sql()[1].strip(), """product_product.template_id = product_template.id""".strip())
@@ -39,8 +39,8 @@ class QueryTestCase(unittest.TestCase):
         query = Query()
         query.tables.extend(['"product_product"', '"product_template"'])
         query.where_clause.append("product_product.template_id = product_template.id")
-        query.add_join(("product_template", "product_category", "categ_id", "id"), outer=False)  # add normal join
-        query.add_join(("product_template__categ_id", "res_user", "user_id", "id"), outer=True)  # CHAINED outer join
+        query.add_join(("product_template", "product_category", "categ_id", "id", "categ_id"), implicit=False, outer=False)  # add normal join
+        query.add_join(("product_template__categ_id", "res_user", "user_id", "id", "user_id"), implicit=False, outer=True)  # CHAINED outer join
         self.assertEquals(query.get_sql()[0].strip(),
             """"product_product","product_template" JOIN "product_category" as "product_template__categ_id" ON ("product_template"."categ_id" = "product_template__categ_id"."id") LEFT JOIN "res_user" as "product_template__categ_id__user_id" ON ("product_template__categ_id"."user_id" = "product_template__categ_id__user_id"."id")""".strip())
         self.assertEquals(query.get_sql()[1].strip(), """product_product.template_id = product_template.id""".strip())
@@ -49,8 +49,8 @@ class QueryTestCase(unittest.TestCase):
         query = Query()
         query.tables.extend(['"product_product"', '"product_template"'])
         query.where_clause.append("product_product.template_id = product_template.id")
-        query.add_join(("product_template", "product_category", "categ_id", "id"), outer=False)  # add normal join
-        query.add_join(("product_template__categ_id", "res_user", "user_id", "id"), outer=True)  # CHAINED outer join
+        query.add_join(("product_template", "product_category", "categ_id", "id", "categ_id"), implicit=False, outer=False)  # add normal join
+        query.add_join(("product_template__categ_id", "res_user", "user_id", "id", "user_id"), implicit=False, outer=True)  # CHAINED outer join
         query.tables.append('"account.account"')
         query.where_clause.append("product_category.expense_account_id = account_account.id")  # additional implicit join
         self.assertEquals(query.get_sql()[0].strip(),
@@ -60,7 +60,7 @@ class QueryTestCase(unittest.TestCase):
     def test_raise_missing_lhs(self):
         query = Query()
         query.tables.append('"product_product"')
-        self.assertRaises(AssertionError, query.add_join, ("product_template", "product_category", "categ_id", "id"), outer=False)
+        self.assertRaises(AssertionError, query.add_join, ("product_template", "product_category", "categ_id", "id", "categ_id"), implicit=False, outer=False)
 
 
 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: