[MERGE] from trunk
diff --git a/openerp/osv/orm.py b/openerp/osv/orm.py
index f324ed6..4f21a15 100644
@@ -42,6 +42,7 @@
 """
 
 import calendar
+import collections
 import copy
 import datetime
 import itertools
@@ -52,18 +53,21 @@ import re
 import simplejson
 import time
 import types
+import warnings
+
+import psycopg2
 from lxml import etree
 
 import fields
 import openerp
 import openerp.netsvc as netsvc
 import openerp.tools as tools
 from openerp.tools.config import config
+from openerp.tools.misc import CountingStream
 from openerp.tools.safe_eval import safe_eval as eval
 from openerp.tools.translate import _
 from openerp import SUPERUSER_ID
 from query import Query
-from openerp import SUPERUSER_ID
 
 _logger = logging.getLogger(__name__)
 _schema = logging.getLogger(__name__ + '.schema')
@@ -1214,7 +1218,11 @@ class BaseModel(object):
         return {'datas': datas}
 
     def import_data(self, cr, uid, fields, datas, mode='init', current_module='', noupdate=False, context=None, filename=None):
-        """Import given data in given module
+        """
+        .. deprecated:: 7.0
+            Use :meth:`~load` instead
+
+        Import the given data into the given module.
 
         This method is used when importing data via the client menu.
 
@@ -1242,7 +1250,7 @@ class BaseModel(object):
         * The last item is currently unused, with no specific semantics
 
         :param fields: list of fields to import
-        :param data: data to import
+        :param datas: data to import
         :param mode: 'init' or 'update' for record creation
         :param current_module: module name
         :param noupdate: flag for record creation
@@ -1250,194 +1258,270 @@ class BaseModel(object):
         :returns: 4-tuple in the form (return_code, errored_resource, error_message, unused)
         :rtype: (int, dict or 0, str or 0, str or 0)
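+
+        A minimal call of this deprecated entry point (``model``, field
+        names and rows below are illustrative)::
+
+            code, _rec, message, _ = model.import_data(cr, uid,
+                ['name'], [['Acme']], mode='init', current_module='base')
+            if code < 0:
+                _logger.warning("import failed: %s", message)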
         """
-        if not context:
-            context = {}
+        context = dict(context) if context is not None else {}
+        context['_import_current_module'] = current_module
+
         fields = map(fix_import_export_id_paths, fields)
         ir_model_data_obj = self.pool.get('ir.model.data')
 
-        # mode: id (XML id) or .id (database id) or False for name_get
-        def _get_id(model_name, id, current_module=False, mode='id'):
-            if mode=='.id':
-                id = int(id)
-                obj_model = self.pool.get(model_name)
-                ids = obj_model.search(cr, uid, [('id', '=', int(id))])
-                if not len(ids):
-                    raise Exception(_("Database ID doesn't exist: %s : %s") %(model_name, id))
-            elif mode=='id':
-                if '.' in id:
-                    module, xml_id = id.rsplit('.', 1)
-                else:
-                    module, xml_id = current_module, id
-                record_id = ir_model_data_obj._get_id(cr, uid, module, xml_id)
-                ir_model_data = ir_model_data_obj.read(cr, uid, [record_id], ['res_id'])
-                if not ir_model_data:
-                    raise ValueError('No references to %s.%s' % (module, xml_id))
-                id = ir_model_data[0]['res_id']
-            else:
-                obj_model = self.pool.get(model_name)
-                ids = obj_model.name_search(cr, uid, id, operator='=', context=context)
-                if not ids:
-                    raise ValueError('No record found for %s' % (id,))
-                id = ids[0][0]
-            return id
-
-        # IN:
-        #   datas: a list of records, each record is defined by a list of values
-        #   prefix: a list of prefix fields ['line_ids']
-        #   position: the line to process, skip is False if it's the first line of the current record
-        # OUT:
-        #   (res, position, warning, res_id) with
-        #     res: the record for the next line to process (including it's one2many)
-        #     position: the new position for the next line
-        #     res_id: the ID of the record if it's a modification
-        def process_liness(self, datas, prefix, current_module, model_name, fields_def, position=0, skip=0):
-            line = datas[position]
-            row = {}
-            warning = []
-            data_res_id = False
-            xml_id = False
-            nbrmax = position+1
-
-            done = {}
-            for i, field in enumerate(fields):
-                res = False
-                if i >= len(line):
-                    raise Exception(_('Please check that all your lines have %d columns.'
-                        'Stopped around line %d having %d columns.') % \
-                            (len(fields), position+2, len(line)))
-                if not line[i]:
-                    continue
-
-                if field[:len(prefix)] != prefix:
-                    if line[i] and skip:
-                        return False
-                    continue
-                field_name = field[len(prefix)]
-
-                #set the mode for m2o, o2m, m2m : xml_id/id/name
-                if len(field) == len(prefix)+1:
-                    mode = False
-                else:
-                    mode = field[len(prefix)+1]
-
-                # TODO: improve this by using csv.csv_reader
-                def many_ids(line, relation, current_module, mode):
-                    res = []
-                    for db_id in line.split(config.get('csv_internal_sep')):
-                        res.append(_get_id(relation, db_id, current_module, mode))
-                    return [(6,0,res)]
-
-                # ID of the record using a XML ID
-                if field_name == 'id':
-                    try:
-                        data_res_id = _get_id(model_name, line[i], current_module)
-                    except ValueError:
-                        pass
-                    xml_id = line[i]
-                    continue
-
-                # ID of the record using a database ID
-                elif field_name == '.id':
-                    data_res_id = _get_id(model_name, line[i], current_module, '.id')
-                    continue
-
-                field_type = fields_def[field_name]['type']
-                # recursive call for getting children and returning [(0,0,{})] or [(1,ID,{})]
-                if field_type == 'one2many':
-                    if field_name in done:
-                        continue
-                    done[field_name] = True
-                    relation = fields_def[field_name]['relation']
-                    relation_obj = self.pool.get(relation)
-                    newfd = relation_obj.fields_get( cr, uid, context=context )
-                    pos = position
-
-                    res = []
-
-                    first = 0
-                    while pos < len(datas):
-                        res2 = process_liness(self, datas, prefix + [field_name], current_module, relation_obj._name, newfd, pos, first)
-                        if not res2:
-                            break
-                        (newrow, pos, w2, data_res_id2, xml_id2) = res2
-                        nbrmax = max(nbrmax, pos)
-                        warning += w2
-                        first += 1
-
-                        if (not newrow) or not reduce(lambda x, y: x or y, newrow.values(), 0):
-                            break
-
-                        res.append( (data_res_id2 and 1 or 0, data_res_id2 or 0, newrow) )
-
-                elif field_type == 'many2one':
-                    relation = fields_def[field_name]['relation']
-                    res = _get_id(relation, line[i], current_module, mode)
-
-                elif field_type == 'many2many':
-                    relation = fields_def[field_name]['relation']
-                    res = many_ids(line[i], relation, current_module, mode)
-
-                elif field_type == 'integer':
-                    res = line[i] and int(line[i]) or 0
-                elif field_type == 'boolean':
-                    res = line[i].lower() not in ('0', 'false', 'off')
-                elif field_type == 'float':
-                    res = line[i] and float(line[i]) or 0.0
-                elif field_type == 'selection':
-                    for key, val in fields_def[field_name]['selection']:
-                        if tools.ustr(line[i]) in [tools.ustr(key), tools.ustr(val)]:
-                            res = key
-                            break
-                    if line[i] and not res:
-                        _logger.warning(
-                            _("key '%s' not found in selection field '%s'"),
-                            tools.ustr(line[i]), tools.ustr(field_name))
-                        warning.append(_("Key/value '%s' not found in selection field '%s'") % (
-                            tools.ustr(line[i]), tools.ustr(field_name)))
-
-                else:
-                    res = line[i]
-
-                row[field_name] = res or False
-
-            return row, nbrmax, warning, data_res_id, xml_id
+        def log(m):
+            if m['type'] == 'error':
+                raise Exception(m['message'])
 
-        fields_def = self.fields_get(cr, uid, context=context)
-
-        position = 0
         if config.get('import_partial') and filename:
             with open(config.get('import_partial'), 'rb') as partial_import_file:
                 data = pickle.load(partial_import_file)
                 position = data.get(filename, 0)
 
-        while position<len(datas):
-            (res, position, warning, res_id, xml_id) = \
-                    process_liness(self, datas, [], current_module, self._name, fields_def, position=position)
-            if len(warning):
-                cr.rollback()
-                return -1, res, 'Line ' + str(position) +' : ' + '!\n'.join(warning), ''
-
-            try:
+        position = 0
+        try:
+            for res_id, xml_id, res, info in self._convert_records(cr, uid,
+                            self._extract_records(cr, uid, fields, datas,
+                                                  context=context, log=log),
+                            context=context, log=log):
                 ir_model_data_obj._update(cr, uid, self._name,
                      current_module, res, mode=mode, xml_id=xml_id,
                      noupdate=noupdate, res_id=res_id, context=context)
-            except Exception, e:
-                return -1, res, 'Line ' + str(position) + ' : ' + tools.ustr(e), ''
-
-            if config.get('import_partial') and filename and (not (position%100)):
-                with open(config.get('import_partial'), 'rb') as partial_import:
-                    data = pickle.load(partial_import)
-                data[filename] = position
-                with open(config.get('import_partial'), 'wb') as partial_import:
-                    pickle.dump(data, partial_import)
-                if context.get('defer_parent_store_computation'):
-                    self._parent_store_compute(cr)
-                cr.commit()
+                position = info.get('rows', {}).get('to', 0) + 1
+                if config.get('import_partial') and filename and (not (position%100)):
+                    with open(config.get('import_partial'), 'rb') as partial_import:
+                        data = pickle.load(partial_import)
+                    data[filename] = position
+                    with open(config.get('import_partial'), 'wb') as partial_import:
+                        pickle.dump(data, partial_import)
+                    if context.get('defer_parent_store_computation'):
+                        self._parent_store_compute(cr)
+                    cr.commit()
+        except Exception, e:
+            cr.rollback()
+            return -1, {}, 'Line %d : %s' % (position + 1, tools.ustr(e)), ''
 
         if context.get('defer_parent_store_computation'):
             self._parent_store_compute(cr)
         return position, 0, 0, 0
 
+    def load(self, cr, uid, fields, data, context=None):
+        """
+        Attempts to load the data matrix, and returns a list of ids (or
+        ``False`` if there was an error and no id could be generated) and a
+        list of messages.
+
+        The ids are those of the records created and saved (in database), in
+        the same order they were extracted from the file. They can be passed
+        directly to :meth:`~read`.
+
+        :param cr: cursor for the request
+        :param int uid: ID of the user attempting the data import
+        :param fields: list of fields to import, at the same index as the corresponding data
+        :type fields: list(str)
+        :param data: row-major matrix of data to import
+        :type data: list(list(str))
+        :param dict context: optional request context
+        :returns: {ids: list(int)|False, messages: [Message]}
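+
+        A minimal usage sketch (model name, external id and values below
+        are illustrative)::
+
+            result = self.pool['res.partner'].load(cr, uid,
+                ['id', 'name', 'is_company'],
+                [['base.main_partner', 'Acme', 'True'],
+                 ['', 'Bob', 'False']])
+            if result['ids'] is False:
+                # each message dict carries at least 'type' and 'message'
+                for m in result['messages']:
+                    _logger.warning("%(type)s: %(message)s", m)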
+        """
+        cr.execute('SAVEPOINT model_load')
+        messages = []
+
+        fields = map(fix_import_export_id_paths, fields)
+        ModelData = self.pool['ir.model.data']
+        fg = self.fields_get(cr, uid, context=context)
+
+        mode = 'init'
+        current_module = ''
+        noupdate = False
+
+        ids = []
+        for id, xid, record, info in self._convert_records(cr, uid,
+                self._extract_records(cr, uid, fields, data,
+                                      context=context, log=messages.append),
+                context=context, log=messages.append):
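+            # each record gets its own savepoint: a failure writing one
+            # record is rolled back in isolation, so the transaction stays
+            # usable and the remaining records can still be attempted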
+            try:
+                cr.execute('SAVEPOINT model_load_save')
+            except psycopg2.InternalError, e:
+                # broken transaction, exit and hope the source error was
+                # already logged
+                if not any(message['type'] == 'error' for message in messages):
+                    messages.append(dict(info, type='error', message=
+                        u"Unknown database error: '%s'" % e))
+                break
+            try:
+                ids.append(ModelData._update(cr, uid, self._name,
+                     current_module, record, mode=mode, xml_id=xid,
+                     noupdate=noupdate, res_id=id, context=context))
+                cr.execute('RELEASE SAVEPOINT model_load_save')
+            except psycopg2.Warning, e:
+                cr.execute('ROLLBACK TO SAVEPOINT model_load_save')
+                messages.append(dict(info, type='warning', message=str(e)))
+            except psycopg2.Error, e:
+                # Failed to write, log to messages, rollback savepoint (to
+                # avoid broken transaction) and keep going
+                cr.execute('ROLLBACK TO SAVEPOINT model_load_save')
+                messages.append(dict(
+                    info, type='error',
+                    **PGERROR_TO_OE[e.pgcode](self, fg, info, e)))
+        if any(message['type'] == 'error' for message in messages):
+            cr.execute('ROLLBACK TO SAVEPOINT model_load')
+            ids = False
+        return {'ids': ids, 'messages': messages}
+
+    def _extract_records(self, cr, uid, fields_, data,
+                         context=None, log=lambda a: None):
+        """ Generates record dicts from the data sequence.
+
+        The result is a generator of dicts mapping field names to raw
+        (unconverted, unvalidated) values.
+
+        For relational fields, if sub-fields were provided, the value will
+        be a list of sub-record dicts.
+
+        The following sub-fields may be set on the record (by key):
+        * None is the name_get for the record (to use with name_create/name_search)
+        * "id" is the External ID for the record
+        * ".id" is the Database ID for the record
+
+        :param log: callable taking a single message dictionary, used to
+            report issues during extraction
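+
+        For instance (the layout below is illustrative), the header
+        ``['name', 'line_ids/name']`` (as split by
+        ``fix_import_export_id_paths``) applied to the rows::
+
+            [['A', 'a1'],
+             ['',  'a2']]
+
+        yields a single record spanning both rows:
+        ``{'name': 'A', 'line_ids': [{'name': 'a1'}, {'name': 'a2'}]}``.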
+        """
+        columns = dict((k, v.column) for k, v in self._all_columns.iteritems())
+        # Fake columns to avoid special cases in extractor
+        columns[None] = fields.char('rec_name')
+        columns['id'] = fields.char('External ID')
+        columns['.id'] = fields.integer('Database ID')
+
+        # many2one values can't span multiple rows, so only one2many columns
+        # are used to decide whether a row continues the current record;
+        # many2one fields are still treated as relational below, since they
+        # may carry subfields (id, .id or the name_get value)
+        is_relational = lambda field: columns[field]._type in ('one2many', 'many2many', 'many2one')
+        get_o2m_values = itemgetter_tuple(
+            [index for index, field in enumerate(fields_)
+                  if columns[field[0]]._type == 'one2many'])
+        get_nono2m_values = itemgetter_tuple(
+            [index for index, field in enumerate(fields_)
+                  if columns[field[0]]._type != 'one2many'])
+        # Checks whether the row holds only one2many values: at least one
+        # non-empty one2many cell and no other non-empty cell, i.e. the row
+        # continues the record started on a previous row
+        def only_o2m_values(row, f=get_nono2m_values, g=get_o2m_values):
+            return any(g(row)) and not any(f(row))
+
+        index = 0
+        while True:
+            if index >= len(data): return
+
+            row = data[index]
+            # copy non-relational fields to record dict
+            record = dict((field[0], value)
+                for field, value in itertools.izip(fields_, row)
+                if not is_relational(field[0]))
+
+            # Get all following rows which have relational values attached to
+            # the current record (no non-relational values)
+            record_span = itertools.takewhile(
+                only_o2m_values, itertools.islice(data, index + 1, None))
+            # stitch record row back on for relational fields
+            record_span = list(itertools.chain([row], record_span))
+            for relfield in set(
+                    field[0] for field in fields_
+                             if is_relational(field[0])):
+                column = columns[relfield]
+                # FIXME: how to not use _obj without relying on fields_get?
+                Model = self.pool[column._obj]
+
+                # get only cells for this sub-field, should be strictly
+                # non-empty, field path [None] is for name_get column
+                indices, subfields = zip(*((index, field[1:] or [None])
+                                           for index, field in enumerate(fields_)
+                                           if field[0] == relfield))
+
+                # return all rows which have at least one value for the
+                # subfields of relfield
+                relfield_data = filter(any, map(itemgetter_tuple(indices), record_span))
+                record[relfield] = [subrecord
+                    for subrecord, _subinfo in Model._extract_records(
+                        cr, uid, subfields, relfield_data,
+                        context=context, log=log)]
+
+            yield record, {'rows': {
+                'from': index,
+                'to': index + len(record_span) - 1
+            }}
+            index += len(record_span)
+
+    def _convert_records(self, cr, uid, records,
+                         context=None, log=lambda a: None):
+        """ Converts records from the source iterable (recursive dicts of
+        strings) into forms which can be written to the database (via
+        self.create or (ir.model.data)._update)
+
+        :param log: callable taking a single message dictionary, used to
+            report issues during conversion
+        :returns: a generator of 4-tuples (dbid, xid, converted record, import info)
+        :rtype: iter((int|False, str|False, dict, dict))
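+
+        A sketch of the intended consumption, mirroring :meth:`~load`::
+
+            extracted = self._extract_records(cr, uid, fields, data)
+            for dbid, xid, record, info in self._convert_records(
+                    cr, uid, extracted):
+                # dbid / xid are False unless a '.id' / 'id' column was
+                # provided; 'record' holds converted, writable field values
+                ...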
+        """
+        if context is None: context = {}
+        Converter = self.pool['ir.fields.converter']
+        columns = dict((k, v.column) for k, v in self._all_columns.iteritems())
+        Translation = self.pool['ir.translation']
+        field_names = dict(
+            (f, (Translation._get_source(cr, uid, self._name + ',' + f, 'field',
+                                         context.get('lang', False) or 'en_US')
+                 or column.string or f))
+            for f, column in columns.iteritems())
+        converters = dict(
+            (k, Converter.to_field(cr, uid, self, column, context=context))
+            for k, column in columns.iteritems())
+
+        def _log(base, field, exception):
+            # recoverable issues are reported as warnings, the rest as errors
+            msg_type = 'warning' if isinstance(exception, Warning) else 'error'
+            record = dict(base, field=field, type=msg_type,
+                          message=unicode(exception.args[0]) % base)
+            if len(exception.args) > 1 and exception.args[1]:
+                record.update(exception.args[1])
+            log(record)
+
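+        # CountingStream keeps track of the 0-based index of the element
+        # currently being consumed (stream.index), used below to point
+        # messages at the offending record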
+        stream = CountingStream(records)
+        for record, extras in stream:
+            dbid = False
+            xid = False
+            converted = {}
+            # the name_get/name_create column (key None) only appears in
+            # sub-records; it is handled by the relational field converters
+            # and skipped along with 'id' and '.id' in the loop below
+            # xid
+            if 'id' in record:
+                xid = record['id']
+            # dbid
+            if '.id' in record:
+                try:
+                    dbid = int(record['.id'])
+                except ValueError:
+                    # in case of overridden id column
+                    dbid = record['.id']
+                if not self.search(cr, uid, [('id', '=', dbid)], context=context):
+                    log(dict(extras,
+                        type='error',
+                        record=stream.index,
+                        field='.id',
+                        message=_(u"Unknown database identifier '%s'") % dbid))
+                    dbid = False
+
+            for field, strvalue in record.iteritems():
+                if field in (None, 'id', '.id'): continue
+                if not strvalue:
+                    converted[field] = False
+                    continue
+
+                # In warnings and error messages, use translated string as
+                # field name
+                message_base = dict(
+                    extras, record=stream.index, field=field_names[field])
+                try:
+                    converted[field], ws = converters[field](strvalue)
+
+                    for w in ws:
+                        if isinstance(w, basestring):
+                            # wrap warning string in an ImportWarning for
+                            # uniform handling
+                            w = ImportWarning(w)
+                        _log(message_base, field, w)
+                except ValueError, e:
+                    _log(message_base, field, e)
+
+            yield dbid, xid, converted, dict(extras, record=stream.index)
+
     def get_invalid_fields(self, cr, uid):
         return list(self._invalids)
 
@@ -4974,7 +5058,7 @@ class BaseModel(object):
     def is_transient(self):
         """ Return whether the model is transient.
 
-        See TransientModel.
+        See :class:`TransientModel`.
 
         """
         return self._transient
@@ -5116,5 +5200,39 @@ class AbstractModel(BaseModel):
     _auto = False # don't create any database backend for AbstractModels
     _register = False # not visible in ORM registry, meant to be python-inherited only
 
+def itemgetter_tuple(items):
+    """ Fixes itemgetter inconsistency (useful in some cases) of not returning
+    a tuple if len(items) == 1: always returns an n-tuple where n = len(items)
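+
+    For example::
+
+        itemgetter_tuple([])(['a', 'b', 'c'])     # ()
+        itemgetter_tuple([1])(['a', 'b', 'c'])    # ('b',)
+        itemgetter_tuple([0, 2])(['a', 'b', 'c']) # ('a', 'c')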
+    """
+    if len(items) == 0:
+        return lambda a: ()
+    if len(items) == 1:
+        return lambda gettable: (gettable[items[0]],)
+    return operator.itemgetter(*items)
+
+class ImportWarning(Warning):
+    """ Used to send warnings upwards the stack during the import process
+    """
+    pass
+
 
+def convert_pgerror_23502(model, fields, info, e):
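+    """ Converts a PostgreSQL not_null_violation into a user-facing message
+    naming the (translated) label of the field which was left empty.
+
+    E.g. (illustrative) the raw ``null value in column "name" violates
+    not-null constraint`` becomes "Missing required value for the field
+    'Name'", with ``field`` set to the column name.
+    """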
+    m = re.match(r'^null value in column "(?P<field>\w+)" violates '
+                 r'not-null constraint\n',
+                 str(e))
+    if not m or m.group('field') not in fields:
+        return {'message': unicode(e)}
+    field = fields[m.group('field')]
+    return {
+        'message': _(u"Missing required value for the field '%(field)s'") % {
+            'field': field['string']
+        },
+        'field': m.group('field'),
+    }
+
+PGERROR_TO_OE = collections.defaultdict(
+    # shape of mapped converters: the default simply forwards the raw
+    # PostgreSQL error message
+    lambda: (lambda model, fvg, info, pgerror: {'message': unicode(pgerror)}), {
+    # not_null_violation
+    '23502': convert_pgerror_23502,
+})
 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: