improvement
authorFabien Pinckaers <fp@tinyerp.com>
Sun, 7 Dec 2008 02:15:41 +0000 (03:15 +0100)
committerFabien Pinckaers <fp@tinyerp.com>
Sun, 7 Dec 2008 02:15:41 +0000 (03:15 +0100)
bzr revid: fp@tinyerp.com-20081207021541-hoxme31i15s1aiog

addons/account/account.py
addons/account/account_view.xml
addons/account/invoice.py
addons/account_tax_include/invoice_tax_incl.py

index 1b618d2..1b7727e 100644 (file)
@@ -1232,6 +1232,7 @@ class account_tax(osv.osv):
         'include_base_amount': fields.boolean('Include in base amount', help="Indicate if the amount of tax must be included in the base amount for the computation of the next taxes"),
         'company_id': fields.many2one('res.company', 'Company', required=True),
         'description': fields.char('Internal Name',size=32),
+        'price_include': fields.boolean('Tax Included in Price', help="Check this is the price you use on the product and invoices is including this tax.")
     }
 
     def name_get(self, cr, uid, ids, context={}):
@@ -1254,6 +1255,7 @@ class account_tax(osv.osv):
         'applicable_type': lambda *a: 'true',
         'type': lambda *a: 'percent',
         'amount': lambda *a: 0,
+        'price_include': lambda *a: 0,
         'active': lambda *a: 1,
         'sequence': lambda *a: 1,
         'tax_group': lambda *a: 'vat',
@@ -1286,11 +1288,8 @@ class account_tax(osv.osv):
         for tax in taxes:
             # we compute the amount for the current tax object and append it to the result
 
-            if tax.type=='percent':
-                amount = cur_price_unit * tax.amount
-                res.append({'id':tax.id,
+            data = {'id':tax.id,
                             'name':tax.name,
-                            'amount':amount,
                             'account_collected_id':tax.account_collected_id.id,
                             'account_paid_id':tax.account_paid_id.id,
                             'base_code_id': tax.base_code_id.id,
@@ -1303,47 +1302,21 @@ class account_tax(osv.osv):
                             'price_unit': cur_price_unit,
                             'tax_code_id': tax.tax_code_id.id,
                             'ref_tax_code_id': tax.ref_tax_code_id.id,
-                            })
+            }
+            res.append(data)
+            if tax.type=='percent':
+                amount = cur_price_unit * tax.amount
+                data['amount'] = amount
 
             elif tax.type=='fixed':
-                res.append({'id':tax.id,
-                            'name':tax.name,
-                            'amount':tax.amount,
-                            'account_collected_id':tax.account_collected_id.id,
-                            'account_paid_id':tax.account_paid_id.id,
-                            'base_code_id': tax.base_code_id.id,
-                            'ref_base_code_id': tax.ref_base_code_id.id,
-                            'sequence': tax.sequence,
-                            'base_sign': tax.base_sign,
-                            'tax_sign': tax.tax_sign,
-                            'ref_base_sign': tax.ref_base_sign,
-                            'ref_tax_sign': tax.ref_tax_sign,
-                            'price_unit': 1,
-                            'tax_code_id': tax.tax_code_id.id,
-                            'ref_tax_code_id': tax.ref_tax_code_id.id,})
+                data['amount'] = tax.amount
             elif tax.type=='code':
                 address = address_id and self.pool.get('res.partner.address').browse(cr, uid, address_id) or None
                 localdict = {'price_unit':cur_price_unit, 'address':address, 'product':product, 'partner':partner}
                 exec tax.python_compute in localdict
                 amount = localdict['result']
-                res.append({
-                    'id': tax.id,
-                    'name': tax.name,
-                    'amount': amount,
-                    'account_collected_id': tax.account_collected_id.id,
-                    'account_paid_id': tax.account_paid_id.id,
-                    'base_code_id': tax.base_code_id.id,
-                    'ref_base_code_id': tax.ref_base_code_id.id,
-                    'sequence': tax.sequence,
-                    'base_sign': tax.base_sign,
-                    'tax_sign': tax.tax_sign,
-                    'ref_base_sign': tax.ref_base_sign,
-                    'ref_tax_sign': tax.ref_tax_sign,
-                    'price_unit': cur_price_unit,
-                    'tax_code_id': tax.tax_code_id.id,
-                    'ref_tax_code_id': tax.ref_tax_code_id.id,
-                })
-            amount2 = res[-1]['amount']
+                data['amount'] = amount
+            amount2 = data['amount']
             if len(tax.child_ids):
                 if tax.child_depend:
                     del res[-1]
@@ -1369,86 +1342,72 @@ class account_tax(osv.osv):
             r['amount'] *= quantity
         return res
 
-    def _unit_compute_inv(self, cr, uid, taxes, price_unit, address_id=None, product=None, partner=None,tax_parent_tot=0.0):
+    def _unit_compute_inv(self, cr, uid, taxes, price_unit, address_id=None, product=None, partner=None):
         taxes = self._applicable(cr, uid, taxes, price_unit, address_id, product, partner)
 
         res = []
         taxes.reverse()
         cur_price_unit=price_unit
+
+        tax_parent_tot = 0.0
         for tax in taxes:
-            # we compute the amount for the current tax object and append it to the result
+            if (tax.type=='percent') and not tax.include_base_amount:
+                tax_parent_tot+=tax.amount
 
+        for tax in taxes:
             if tax.type=='percent':
-                amount = cur_price_unit - (cur_price_unit / (1 + tax.amount))
-                res.append({'id':tax.id,
-                            'name':tax.name,
-                            'amount':amount,
-                            'account_collected_id':tax.account_collected_id.id,
-                            'account_paid_id':tax.account_paid_id.id,
-                            'base_code_id': tax.base_code_id.id,
-                            'ref_base_code_id': tax.ref_base_code_id.id,
-                            'sequence': tax.sequence,
-                            'base_sign': tax.base_sign,
-                            'tax_sign': tax.tax_sign,
-                            'ref_base_sign': tax.ref_base_sign,
-                            'ref_tax_sign': tax.ref_tax_sign,
-                            'price_unit': cur_price_unit - amount,
-                            'tax_code_id': tax.tax_code_id.id,
-                            'ref_tax_code_id': tax.ref_tax_code_id.id,})
+                if tax.include_base_amount:
+                    amount = cur_price_unit - (cur_price_unit / (1 + tax.amount))
+                else:
+                    amount = (cur_price_unit / (1 + tax_parent_tot)) * tax.amount
 
             elif tax.type=='fixed':
-                res.append({'id':tax.id,
-                            'name':tax.name,
-                            'amount':tax.amount,
-                            'account_collected_id':tax.account_collected_id.id,
-                            'account_paid_id':tax.account_paid_id.id,
-                            'base_code_id': tax.base_code_id.id,
-                            'ref_base_code_id': tax.ref_base_code_id.id,
-                            'sequence': tax.sequence,
-                            'base_sign': tax.base_sign,
-                            'tax_sign': tax.tax_sign,
-                            'ref_base_sign': tax.ref_base_sign,
-                            'ref_tax_sign': tax.ref_tax_sign,
-                            'price_unit': 1,
-                            'tax_code_id': tax.tax_code_id.id,
-                            'ref_tax_code_id': tax.ref_tax_code_id.id,})
+                amount = tax.amount
 
             elif tax.type=='code':
                 address = address_id and self.pool.get('res.partner.address').browse(cr, uid, address_id) or None
                 localdict = {'price_unit':cur_price_unit, 'address':address, 'product':product, 'partner':partner}
                 exec tax.python_compute_inv in localdict
                 amount = localdict['result']
-                res.append({
-                    'id': tax.id,
-                    'name': tax.name,
-                    'amount': amount,
-                    'account_collected_id': tax.account_collected_id.id,
-                    'account_paid_id': tax.account_paid_id.id,
-                    'base_code_id': tax.base_code_id.id,
-                    'ref_base_code_id': tax.ref_base_code_id.id,
-                    'sequence': tax.sequence,
-                    'base_sign': tax.base_sign,
-                    'tax_sign': tax.tax_sign,
-                    'ref_base_sign': tax.ref_base_sign,
-                    'ref_tax_sign': tax.ref_tax_sign,
-                    'price_unit': cur_price_unit - amount,
-                    'tax_code_id': tax.tax_code_id.id,
-                    'ref_tax_code_id': tax.ref_tax_code_id.id,
-                })
 
-            amount2 = res[-1]['amount']
+            if tax.include_base_amount:
+                cur_price_unit -= amount
+                todo = 0
+            else:
+                todo = 1
+            res.append({
+                'id': tax.id,
+                'todo': todo,
+                'name': tax.name,
+                'amount': amount,
+                'account_collected_id': tax.account_collected_id.id,
+                'account_paid_id': tax.account_paid_id.id,
+                'base_code_id': tax.base_code_id.id,
+                'ref_base_code_id': tax.ref_base_code_id.id,
+                'sequence': tax.sequence,
+                'base_sign': tax.base_sign,
+                'tax_sign': tax.tax_sign,
+                'ref_base_sign': tax.ref_base_sign,
+                'ref_tax_sign': tax.ref_tax_sign,
+                'price_unit': cur_price_unit,
+                'tax_code_id': tax.tax_code_id.id,
+                'ref_tax_code_id': tax.ref_tax_code_id.id,
+            })
             if len(tax.child_ids):
                 if tax.child_depend:
                     del res[-1]
                     amount = price_unit
-                else:
-                    amount = amount2
-            for t in tax.child_ids:
-                parent_tax = self._unit_compute_inv(cr, uid, [t], amount, address_id, product, partner)
-                res.extend(parent_tax)
-            if tax.include_base_amount:
-                cur_price_unit-=amount
-        taxes.reverse()
+
+            parent_tax = self._unit_compute_inv(cr, uid, tax.child_ids, amount, address_id, product, partner)
+            res.extend(parent_tax)
+
+        total = 0.0
+        for r in res:
+            if r['todo']:
+                total += r['amount']
+        for r in res:
+            r['price_unit'] -= total
+            r['todo'] = 0
         return res
 
     def compute_inv(self, cr, uid, taxes, price_unit, quantity, address_id=None, product=None, partner=None):
index c305351..c73b829 100644 (file)
             <field name="arch" type="xml">
                 <tree string="Account Tax">
                     <field name="name"/>
+                    <field name="price_include"/>
                     <field name="description"/>
                 </tree>
             </field>
                             <label colspan="2" nolabel="1" string="Keep empty to use the expense account"/>
                             <field groups="base.group_extended" name="child_depend"/>
                             <field groups="base.group_extended" name="sequence"/>
+                            <field groups="base.group_extended" name="price_include"/>
                             <newline/>
                             <field colspan="4" groups="base.group_extended" name="child_ids"/>
                         </page>
index a511d79..a67653b 100644 (file)
@@ -31,24 +31,19 @@ from tools import config
 from tools.translate import _
 
 class account_invoice(osv.osv):
-    def _amount_untaxed(self, cr, uid, ids, name, args, context={}):
-        id_set=",".join(map(str,ids))
-        cr.execute("SELECT s.id,COALESCE(SUM(l.price_subtotal),0)::decimal(16,2) AS amount FROM account_invoice s LEFT OUTER JOIN account_invoice_line l ON (s.id=l.invoice_id) WHERE s.id IN ("+id_set+") GROUP BY s.id ")
-        res=dict(cr.fetchall())
-        return res
-
-    def _amount_tax(self, cr, uid, ids, name, args, context={}):
-        id_set=",".join(map(str,ids))
-        cr.execute("SELECT s.id,COALESCE(SUM(l.amount),0)::decimal(16,2) AS amount FROM account_invoice s LEFT OUTER JOIN account_invoice_tax l ON (s.id=l.invoice_id) WHERE s.id IN ("+id_set+") GROUP BY s.id ")
-        res=dict(cr.fetchall())
-        return res
-
-    def _amount_total(self, cr, uid, ids, name, args, context={}):
-        untax = self._amount_untaxed(cr, uid, ids, name, args, context)
-        tax = self._amount_tax(cr, uid, ids, name, args, context)
+    def _amount_all(self, cr, uid, ids, name, args, context={}):
         res = {}
-        for id in ids:
-            res[id] = untax.get(id,0.0) + tax.get(id,0.0)
+        for invoice in self.browse(cr,uid,ids):
+            res[invoice.id] = {
+                'amount_untaxed': 0.0,
+                'amount_tax': 0.0,
+                'amount_total': 0.0
+            }
+            for line in invoice.invoice_line:
+                res[invoice.id]['amount_untaxed'] += line.price_subtotal
+            for line in invoice.tax_line:
+                res[invoice.id]['amount_tax'] += line.amount
+            res[invoice.id]['amount_total'] = res[invoice.id]['amount_tax'] + res[invoice.id]['amount_untaxed']
         return res
 
     def _get_journal(self, cr, uid, context):
@@ -118,6 +113,12 @@ class account_invoice(osv.osv):
                 res[id]=[x for x in l if x <> line.id]
         return res
 
+    def _get_invoice_tax(self, cr, uid, ids, context={}):
+        result = {}
+        for tax in self.pool.get('account.invoice.tax').browse(cr, uid, ids, context=context):
+            result[tax.invoice_id.id] = True
+        return result.keys()
+
     def _compute_lines(self, cr, uid, ids, name, args, context={}):
         result = {}
         for invoice in self.browse(cr, uid, ids, context):
@@ -178,14 +179,30 @@ class account_invoice(osv.osv):
         'tax_line': fields.one2many('account.invoice.tax', 'invoice_id', 'Tax Lines', readonly=True, states={'draft':[('readonly',False)]}),
 
         'move_id': fields.many2one('account.move', 'Invoice Movement', readonly=True, help="Link to the automatically generated account moves."),
-        'amount_untaxed': fields.function(_amount_untaxed, method=True, digits=(16,2),string='Untaxed', store=True),
-        'amount_tax': fields.function(_amount_tax, method=True, digits=(16,2), string='Tax', store=True),
-        'amount_total': fields.function(_amount_total, method=True, digits=(16,2), string='Total', store=True),
+        'amount_untaxed': fields.function(_amount_all, method=True, digits=(16,2),string='Untaxed',
+            store={
+                'account.invoice': (lambda self, cr, uid, ids, c={}: ids, None, 10),
+                'account.invoice.tax': (_get_invoice_tax, None, 10),
+            },
+            multi='all'),
+        'amount_tax': fields.function(_amount_all, method=True, digits=(16,2), string='Tax',
+            store={
+                'account.invoice': (lambda self, cr, uid, ids, c={}: ids, None, 10),
+                'account.invoice.tax': (_get_invoice_tax, None, 10),
+            },
+            multi='all'),
+        'amount_total': fields.function(_amount_all, method=True, digits=(16,2), string='Total',
+            store={
+                'account.invoice': (lambda self, cr, uid, ids, c={}: ids, None, 10),
+                'account.invoice.tax': (_get_invoice_tax, None, 10),
+            },
+            multi='all'),
         'currency_id': fields.many2one('res.currency', 'Currency', required=True, readonly=True, states={'draft':[('readonly',False)]}),
         'journal_id': fields.many2one('account.journal', 'Journal', required=True,readonly=True, states={'draft':[('readonly',False)]}),
         'company_id': fields.many2one('res.company', 'Company', required=True),
         'check_total': fields.float('Total', digits=(16,2), states={'open':[('readonly',True)],'close':[('readonly',True)]}),
-        'reconciled': fields.function(_reconciled, method=True, string='Paid/Reconciled', type='boolean', store=True, help="The account moves of the invoice have been reconciled with account moves of the payment(s)."),
+        'reconciled': fields.function(_reconciled, method=True, string='Paid/Reconciled', type='boolean',
+            store=True, help="The account moves of the invoice have been reconciled with account moves of the payment(s)."),
         'partner_bank': fields.many2one('res.partner.bank', 'Bank Account',
             help='The bank account to pay to or to be paid from'),
         'move_lines':fields.function(_get_lines , method=True,type='many2many' , relation='account.move.line',string='Move Lines'),
index 66ec0ec..be2c1e3 100644 (file)
@@ -26,42 +26,12 @@ from osv import fields, osv
 import ir
 
 class account_invoice(osv.osv):
-    def _amount_untaxed(self, cr, uid, ids, name, args, context={}):
-        res = {}
-        for invoice in self.browse(cr,uid,ids):
-            if invoice.price_type == 'tax_included':
-                res[invoice.id] = reduce( lambda x, y: x+y.price_subtotal, invoice.invoice_line,0)
-            else:
-                res[invoice.id] = super(account_invoice, self)._amount_untaxed(cr, uid, [invoice.id], name, args, context)[invoice.id]
-        return res
-
-    def _amount_tax(self, cr, uid, ids, name, args, context={}):
-        res = {}
-        for invoice in self.browse(cr,uid,ids):
-            if invoice.price_type == 'tax_included':
-                res[invoice.id] = reduce( lambda x, y: x+y.amount, invoice.tax_line,0)
-            else:
-                res[invoice.id] = super(account_invoice, self)._amount_tax(cr, uid, [invoice.id], name, args, context)[invoice.id]
-        return res
-
-    def _amount_total(self, cr, uid, ids, name, args, context={}):
-        res = {}
-        for invoice in self.browse(cr,uid,ids):
-            if invoice.price_type == 'tax_included':
-                res[invoice.id]= invoice.amount_untaxed + invoice.amount_tax
-            else:
-                res[invoice.id] = super(account_invoice, self)._amount_total(cr, uid, [invoice.id], name, args, context)[invoice.id]
-        return res
-
     _inherit = "account.invoice"
     _columns = {
         'price_type': fields.selection([('tax_included','Tax included'),
                                         ('tax_excluded','Tax excluded')],
                                         'Price method', required=True, readonly=True,
                                         states={'draft':[('readonly',False)]}),
-        'amount_untaxed': fields.function(_amount_untaxed, digits=(16,2), method=True,string='Untaxed Amount'),
-        'amount_tax': fields.function(_amount_tax, method=True, string='Tax', store=True),
-        'amount_total': fields.function(_amount_total, method=True, string='Total', store=True),
     }
     _defaults = {
         'price_type': lambda *a: 'tax_excluded',
@@ -70,47 +40,51 @@ account_invoice()
 
 class account_invoice_line(osv.osv):
     _inherit = "account.invoice.line"
-    def _amount_line(self, cr, uid, ids, name, args, context={}):
+    def _amount_line2(self, cr, uid, ids, name, args, context={}):
         """
         Return the subtotal excluding taxes with respect to price_type.
         """
         res = {}
         tax_obj = self.pool.get('account.tax')
-        res = super(account_invoice_line, self)._amount_line(cr, uid, ids, name, args, context)
-        res2 = res.copy()
+        res_init = super(account_invoice_line, self)._amount_line(cr, uid, ids, name, args, context)
         for line in self.browse(cr, uid, ids):
+            res[line.id] = {
+                'price_subtotal': 0.0,
+                'price_subtotal_incl': 0.0,
+                'data': []
+            }
             if not line.quantity:
-                res[line.id] = 0.0
                 continue
             if line.invoice_id and line.invoice_id.price_type == 'tax_included':
-                product_taxes = None
+                product_taxes = []
                 if line.product_id:
                     if line.invoice_id.type in ('out_invoice', 'out_refund'):
-                        product_taxes = line.product_id.taxes_id
+                        product_taxes = filter(lambda x: x.price_include, line.product_id.taxes_id)
                     else:
-                        product_taxes = line.product_id.supplier_taxes_id
-                if product_taxes:
-                    for tax in tax_obj.compute_inv(cr, uid, product_taxes, res[line.id]/line.quantity, line.quantity):
-                        res[line.id] = res[line.id] - round(tax['amount'], 2)
-                else:
-                    for tax in tax_obj.compute_inv(cr, uid,line.invoice_line_tax_id, res[line.id]/line.quantity, line.quantity):
-                        res[line.id] = res[line.id] - round(tax['amount'], 2)
-            if name == 'price_subtotal_incl' and line.invoice_id and line.invoice_id.price_type == 'tax_included':
-                prod_taxe_ids = None
-                line_taxe_ids = None
-                if product_taxes:
-                    prod_taxe_ids = [ t.id for t in product_taxes ]
-                    prod_taxe_ids.sort()
-                    line_taxe_ids = [ t.id for t in line.invoice_line_tax_id ]
-                    line_taxe_ids.sort()
-                if product_taxes and prod_taxe_ids == line_taxe_ids:
-                    res[line.id] = res2[line.id]
-                elif not line.product_id:
-                    res[line.id] = res2[line.id]
+                        product_taxes = filter(lambda x: x.price_include, line.product_id.supplier_taxes_id)
+
+                if (set(product_taxes) == set(line.invoice_line_tax_id)) or not product_taxes:
+                    res[line.id]['price_subtotal_incl'] = res_init[line.id]
                 else:
-                    for tax in tax_obj.compute(cr, uid, line.invoice_line_tax_id, res[line.id]/line.quantity, line.quantity):
-                        res[line.id] = res[line.id] + tax['amount']
-            res[line.id]= round(res[line.id], 2)
+                    res[line.id]['price_subtotal'] = res_init[line.id]
+                    for tax in tax_obj.compute_inv(cr, uid, product_taxes, res_init[line.id]/line.quantity, line.quantity):
+                        res[line.id]['price_subtotal'] = res[line.id]['price_subtotal'] - round(tax['amount'], 2)
+            else:
+                res[line.id]['price_subtotal'] = res_init[line.id]
+
+            if res[line.id]['price_subtotal']:
+                res[line.id]['price_subtotal_incl'] = res[line.id]['price_subtotal']
+                for tax in tax_obj.compute(cr, uid, line.invoice_line_tax_id, res[line.id]['price_subtotal']/line.quantity, line.quantity):
+                    res[line.id]['price_subtotal_incl'] = res[line.id]['price_subtotal_incl'] + tax['amount']
+                    res[line.id]['data'].append( tax)
+            else:
+                res[line.id]['price_subtotal'] = res[line.id]['price_subtotal_incl']
+                for tax in tax_obj.compute_inv(cr, uid, line.invoice_line_tax_id, res[line.id]['price_subtotal_incl']/line.quantity, line.quantity):
+                    res[line.id]['price_subtotal'] = res[line.id]['price_subtotal'] - tax['amount']
+                    res[line.id]['data'].append( tax)
+
+        res[line.id]['price_subtotal']= round(res[line.id]['price_subtotal'], 2)
+        res[line.id]['price_subtotal_incl']= round(res[line.id]['price_subtotal_incl'], 2)
         return res
 
     def _price_unit_default(self, cr, uid, context={}):
@@ -125,9 +99,17 @@ class account_invoice_line(osv.osv):
             return super(account_invoice_line, self)._price_unit_default(cr, uid, context)
         return 0
 
+    def _get_invoice(self, cr, uid, ids, context):
+        result = {}
+        for inv in self.pool.get('account.invoice').browse(cr, uid, ids, context=context):
+            for line in inv.invoice_line:
+                result[line.id] = True
+        return result.keys()
     _columns = {
-        'price_subtotal': fields.function(_amount_line, method=True, string='Subtotal w/o tax', store=True),
-        'price_subtotal_incl': fields.function(_amount_line, method=True, string='Subtotal'),
+        'price_subtotal': fields.function(_amount_line2, method=True, string='Subtotal w/o tax', multi='amount',
+            store={'account.invoice':(_get_invoice,None), 'account.invoice.line': (lambda self,cr,uid,ids,c={}: ids, None)}),
+        'price_subtotal_incl': fields.function(_amount_line2, method=True, string='Subtotal', multi='amount',
+            store={'account.invoice':(_get_invoice,None), 'account.invoice.line': (lambda self,cr,uid,ids,c={}: ids, None)}),
     }
 
     _defaults = {
@@ -147,32 +129,36 @@ class account_invoice_line(osv.osv):
                 'account_analytic_id':line.account_analytic_id.id,
             }
 
-    def product_id_change_unit_price_inv(self, cr, uid, tax_id, price_unit, qty, address_invoice_id, product, partner_id, context={}):
-        if context.get('price_type', False) == 'tax_included':
-            return {'price_unit': price_unit,'invoice_line_tax_id': tax_id}
-        else:
-            return super(account_invoice_line, self).product_id_change_unit_price_inv(cr, uid, tax_id, price_unit, qty, address_invoice_id, product, partner_id, context=context)
-
-    def product_id_change(self, cr, uid, ids, product, uom, qty=0, name='', type='out_invoice', partner_id=False, price_unit=False, address_invoice_id=False, price_type='tax_excluded', context={}):
-        context.update({'price_type': price_type})
-        return super(account_invoice_line, self).product_id_change(cr, uid, ids, product, uom, qty, name, type, partner_id, price_unit, address_invoice_id, context=context)
+# TODO: check why ?
+#
+#    def product_id_change_unit_price_inv(self, cr, uid, tax_id, price_unit, qty, address_invoice_id, product, partner_id, context={}):
+#        if context.get('price_type', False) == 'tax_included':
+#            return {'price_unit': price_unit,'invoice_line_tax_id': tax_id}
+#        else:
+#            return super(account_invoice_line, self).product_id_change_unit_price_inv(cr, uid, tax_id, price_unit, qty, address_invoice_id, product, partner_id, context=context)
+#
+#    def product_id_change(self, cr, uid, ids, product, uom, qty=0, name='', type='out_invoice', partner_id=False, price_unit=False, address_invoice_id=False, price_type='tax_excluded', context={}):
+#        context.update({'price_type': price_type})
+#        return super(account_invoice_line, self).product_id_change(cr, uid, ids, product, uom, qty, name, type, partner_id, price_unit, address_invoice_id, context=context)
 account_invoice_line()
 
 class account_invoice_tax(osv.osv):
     _inherit = "account.invoice.tax"
 
-    def compute(self, cr, uid, invoice_id):
+    def compute(self, cr, uid, invoice_id, context={}):
+        inv = self.pool.get('account.invoice').browse(cr, uid, invoice_id)
+        line_ids = map(lambda x: x.id, inv.invoice_line)
+
+
+
         tax_grouped = {}
         tax_obj = self.pool.get('account.tax')
         cur_obj = self.pool.get('res.currency')
-        inv = self.pool.get('account.invoice').browse(cr, uid, invoice_id)
         cur = inv.currency_id
 
-        if inv.price_type=='tax_excluded':
-            return super(account_invoice_tax,self).compute(cr, uid, invoice_id)
-
         for line in inv.invoice_line:
-            for tax in tax_obj.compute_inv(cr, uid, line.invoice_line_tax_id, (line.price_unit * (1-(line.discount or 0.0)/100.0)), line.quantity, inv.address_invoice_id.id, line.product_id, inv.partner_id):
+            data = self.pool.get('account.invoice.line')._amount_line2(cr, uid, [line.id], [], [], context)[line.id]
+            for tax in data['data']:
                 val={}
                 val['invoice_id'] = inv.id
                 val['name'] = tax['name']