[MERGE]: Merge with latest trunk-server
[odoo/odoo.git] / openerp / tools / yaml_import.py
index 6600fb1..ed568b6 100644 (file)
@@ -11,6 +11,8 @@ import misc
 from config import config
 import yaml_tag
 import yaml
+import re
+from lxml import etree
 
 # YAML import needs both safe and unsafe eval, but let's
 # default to /safe/.
@@ -164,9 +166,11 @@ class YamlInterpreter(object):
             self.logger.log(logging.ERROR, 'id: %s is to long (max: 64)', id)
 
     def get_id(self, xml_id):
-        if not xml_id:
-            raise YamlImportException("The xml_id should be a non empty string.")
-        if isinstance(xml_id, types.IntType):
+        if xml_id is False:
+            return False
+        #if not xml_id:
+        #    raise YamlImportException("The xml_id should be a non empty string.")
+        elif isinstance(xml_id, types.IntType):
             id = xml_id
         elif xml_id in self.id_map:
             id = self.id_map[xml_id]
@@ -298,7 +302,7 @@ class YamlInterpreter(object):
 
     def create_osv_memory_record(self, record, fields):
         model = self.get_model(record.model)
-        record_dict = self._create_record(model, fields)
+        record_dict = self._create_record(model, fields, False)
         id_new=model.create(self.cr, self.uid, record_dict, context=self.context)
         self.id_map[record.id] = int(id_new)
         return record_dict
@@ -307,10 +311,21 @@ class YamlInterpreter(object):
         import openerp.osv as osv
         record, fields = node.items()[0]
         model = self.get_model(record.model)
+
+        view_id = record.view
+        if view_id and (view_id is not True):
+            view_id = self.pool.get('ir.model.data').get_object_reference(self.cr, 1, self.module, record.view)[1]
+
         if model.is_transient():
             record_dict=self.create_osv_memory_record(record, fields)
         else:
             self.validate_xml_id(record.id)
+            try:
+                self.pool.get('ir.model.data')._get_id(self.cr, 1, self.module, record.id)
+                default = False
+            except ValueError:
+                default = True
+
             if self.isnoupdate(record) and self.mode != 'init':
                 id = self.pool.get('ir.model.data')._update_dummy(self.cr, 1, record.model, self.module, record.id)
                 # check if the resource already existed at the last update
@@ -321,22 +336,110 @@ class YamlInterpreter(object):
                     if not self._coerce_bool(record.forcecreate):
                         return None
 
-            record_dict = self._create_record(model, fields)
-            self.logger.debug("RECORD_DICT %s" % record_dict)
+
             #context = self.get_context(record, self.eval_context)
-            context = record.context #TOFIX: record.context like {'withoutemployee':True} should pass from self.eval_context. example: test_project.yml in project module
+            #TOFIX: record.context like {'withoutemployee':True} should pass from self.eval_context. example: test_project.yml in project module
+            context = record.context
+            if view_id:
+                varg = view_id
+                if view_id is True: varg = False
+                view = model.fields_view_get(self.cr, 1, varg, 'form', context)
+                view_id = etree.fromstring(view['arch'].encode('utf-8'))
+
+            record_dict = self._create_record(model, fields, view_id, default=default)
+            self.logger.debug("RECORD_DICT %s" % record_dict)
             id = self.pool.get('ir.model.data')._update(self.cr, 1, record.model, \
                     self.module, record_dict, record.id, noupdate=self.isnoupdate(record), mode=self.mode, context=context)
             self.id_map[record.id] = int(id)
             if config.get('import_partial'):
                 self.cr.commit()
 
-    def _create_record(self, model, fields):
+    def _create_record(self, model, fields, view=False, parent={}, default=True):
+        allfields = model.fields_get(self.cr, 1, context=self.context)
+        if view is not False:
+            defaults = default and model.default_get(self.cr, 1, allfields, context=self.context) or {}
+            fg = model.fields_get(self.cr, 1, context=self.context)
+        else:
+            default = {}
+            fg = {}
         record_dict = {}
         fields = fields or {}
+
+        def process_val(key, val):
+            if fg[key]['type']=='many2one':
+                if type(val) in (tuple,list):
+                    val = val[0]
+            elif (fg[key]['type']=='one2many'):
+                if val is False:
+                    val = []
+                if len(val) and type(val[0]) == dict:
+                    val = map(lambda x: (0,0,x), val)
+            return val
+
+        # Process all on_change calls
+        nodes = (view is not False) and [view] or []
+        while nodes:
+            el = nodes.pop(0)
+            if el.tag=='field':
+                field_name = el.attrib['name']
+                assert field_name in fg, "The field '%s' is defined in the form view but not on the object '%s'!" % (field_name, model._name)
+                if field_name in fields:
+                    view2 = None
+                    # if the form view is not inline, we call fields_view_get
+                    if (view is not False) and (fg[field_name]['type']=='one2many'):
+                        view2 = view.find("field[@name='%s']/form"%(field_name,))
+                        if not view2:
+                            view2 = self.pool.get(fg[field_name]['relation']).fields_view_get(self.cr, 1, False, 'form', self.context)
+                            view2 = etree.fromstring(view2['arch'].encode('utf-8'))
+
+                    field_value = self._eval_field(model, field_name, fields[field_name], view2, parent=record_dict, default=default)
+                    record_dict[field_name] = field_value
+                    #if (field_name in defaults) and defaults[field_name] == field_value:
+                    #    print '*** You can remove these lines:', field_name, field_value
+                elif (field_name in defaults):
+                    if (field_name not in record_dict):
+                        record_dict[field_name] = process_val(field_name, defaults[field_name])
+                else:
+                    continue
+
+                if not el.attrib.get('on_change', False):
+                    continue
+                match = re.match("([a-z_1-9A-Z]+)\((.*)\)", el.attrib['on_change'])
+                assert match, "Unable to parse the on_change '%s'!" % (el.attrib['on_change'], )
+
+                # creating the context
+                class parent2(object):
+                    def __init__(self, d):
+                        self.d = d
+                    def __getattr__(self, name):
+                        return self.d.get(name, False)
+
+                ctx = record_dict.copy()
+                ctx['context'] = self.context
+                ctx['uid'] = 1
+                ctx['parent'] = parent2(parent)
+                for a in fg:
+                    if a not in ctx:
+                        ctx[a]=process_val(a, defaults.get(a, False))
+
+                # Evaluation args
+                args = map(lambda x: eval(x, ctx), match.group(2).split(','))
+                result = getattr(model, match.group(1))(self.cr, 1, [], *args)
+                for key, val in (result or {}).get('value', {}).items():
+                    if key not in fields:
+                        assert key in fg, "The returning field '%s' from your on_change call '%s' does not exist on the object '%s'" % (key, match.group(1), model._name)
+                        record_dict[key] = process_val(key, val)
+                        #if (key in fields) and record_dict[key] == process_val(key, val):
+                        #    print '*** You can remove these lines:', key, val
+            else:
+                nodes = list(el) + nodes
+
         for field_name, expression in fields.items():
-            field_value = self._eval_field(model, field_name, expression)
+            if field_name in record_dict:
+                continue
+            field_value = self._eval_field(model, field_name, expression, default=False)
             record_dict[field_name] = field_value
+
         return record_dict
 
     def process_ref(self, node, column=None):
@@ -365,7 +468,7 @@ class YamlInterpreter(object):
     def process_eval(self, node):
         return eval(node.expression, self.eval_context)
 
-    def _eval_field(self, model, field_name, expression):
+    def _eval_field(self, model, field_name, expression, view=False, parent={}, default=True):
         # TODO this should be refactored as something like model.get_field() in bin/osv
         if field_name in model._columns:
             column = model._columns[field_name]
@@ -386,7 +489,7 @@ class YamlInterpreter(object):
             value = self.get_id(expression)
         elif column._type == "one2many":
             other_model = self.get_model(column._obj)
-            value = [(0, 0, self._create_record(other_model, fields)) for fields in expression]
+            value = [(0, 0, self._create_record(other_model, fields, view, parent, default=default)) for fields in expression]
         elif column._type == "many2many":
             ids = [self.get_id(xml_id) for xml_id in expression]
             value = [(6, 0, ids)]