[FIX] auth_openid: use set_cookie_and_redirect + handle errors correctly
[odoo/odoo.git] / addons / auth_openid / controllers / main.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    OpenERP, Open Source Management Solution
5 #    Copyright (C) 2010-2012 OpenERP s.a. (<http://openerp.com>).
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU Affero General Public License as
9 #    published by the Free Software Foundation, either version 3 of the
10 #    License, or (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU Affero General Public License for more details.
16 #
17 #    You should have received a copy of the GNU Affero General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
import logging
import os
import tempfile
import urllib

import werkzeug.exceptions
import werkzeug.urls
import werkzeug.utils

from openerp.modules.registry import RegistryManager
from openerp.addons.web.controllers.main import login_and_redirect, set_cookie_and_redirect
try:
    import openerp.addons.web.common.http as openerpweb
except ImportError:
    import web.common.http as openerpweb    # noqa

from openid import oidutil
from openid.store import filestore
from openid.consumer import consumer
from openid.cryptutil import randomString
from openid.extensions import ax, sreg

from .. import utils
44
_logger = logging.getLogger(__name__)
# Send python-openid's internal log output to this module's logger, at DEBUG level.
oidutil.log = _logger.debug

# Directory (under the system temp dir) backing the FileOpenIDStore that
# OpenIDController creates at class-definition time.
_storedir = os.path.join(tempfile.gettempdir(), 'openerp-auth_openid-store')
49
class GoogleAppsAwareConsumer(consumer.GenericConsumer):
    """GenericConsumer variant that patches Google Apps OpenID 2.0 responses.

    When a response comes from a ``https://www.google.com/a/`` endpoint and
    a matching association is found in the store, ``claimed_id`` and
    ``identity`` are rewritten to Google's ``user-xrds`` URL form.  The
    message is then re-signed with that association before the generic
    verification runs.  Presumably this is needed for the generic checks to
    accept Google Apps assertions -- TODO confirm against python-openid.
    """

    def complete(self, message, endpoint, return_to):
        """Verify ``message``, first patching it if it is a Google Apps response."""
        patched = self._resign_google_apps_message(message)
        return super(GoogleAppsAwareConsumer, self).complete(patched, endpoint, return_to)

    def _resign_google_apps_message(self, message):
        """Return ``message`` untouched, or a rewritten and re-signed copy.

        The rewrite only happens for OpenID 2.0 messages whose ``op_endpoint``
        is a Google Apps URL and whose ``assoc_handle`` is known to the store.
        """
        if message.getOpenIDNamespace() != consumer.OPENID2_NS:
            return message

        server_url = message.getArg(consumer.OPENID2_NS, 'op_endpoint', '')
        if not server_url.startswith('https://www.google.com/a/'):
            return message

        handle = message.getArg(consumer.OPENID_NS, 'assoc_handle')
        association = self.store.getAssociation(server_url, handle)
        if not association:
            return message

        for field in ('claimed_id', 'identity'):
            current = message.getArg(consumer.OPENID2_NS, field, '')
            rewritten = 'https://www.google.com/accounts/o8/user-xrds?uri=%s' % urllib.quote_plus(current)
            message.setArg(consumer.OPENID2_NS, field, rewritten)

        # The old signature no longer matches the rewritten fields: drop it
        # and sign again with the stored association.
        for field in ('sig', 'signed'):
            message.delArg(consumer.OPENID2_NS, field)
        return association.signMessage(message)
70
71
class OpenIDController(openerpweb.Controller):
    """Relying-party endpoints for OpenID login, mounted on ``/auth_openid/login``.

    * ``verify`` / ``verify_direct`` start authentication (discovery, then
      redirect or form POST to the provider).
    * ``process`` is the ``return_to`` target.  It checks the provider's
      answer and logs the matching user in.
    * ``status`` reports the outcome of the last attempt stored in the
      web session.
    """
    _cp_path = '/auth_openid/login'

    # python-openid store shared by every request of this controller.
    _store = filestore.FileOpenIDStore(_storedir)

    # Attributes requested through both SReg and AX; every alias must be a
    # key of ``utils.SREG2AX``.
    _REQUIRED_ATTRIBUTES = ['email']
    _OPTIONAL_ATTRIBUTES = 'nickname fullname postcode country language timezone'.split()

    def _add_extensions(self, request):
        """Attach SReg and AX attribute requests to an OpenID auth request.

        :param request: the auth request returned by ``Consumer.begin``
        """
        sreg_request = sreg.SRegRequest(required=self._REQUIRED_ATTRIBUTES,
                                        optional=self._OPTIONAL_ATTRIBUTES)
        request.addExtension(sreg_request)

        ax_request = ax.FetchRequest()
        for alias in self._REQUIRED_ATTRIBUTES:
            uri = utils.SREG2AX[alias]
            ax_request.add(ax.AttrInfo(uri, required=True, alias=alias))
        for alias in self._OPTIONAL_ATTRIBUTES:
            uri = utils.SREG2AX[alias]
            ax_request.add(ax.AttrInfo(uri, required=False, alias=alias))

        request.addExtension(ax_request)

    def _get_attributes_from_success_response(self, success_response):
        """Extract the requested user attributes from a successful response.

        SReg values are read first, and AX values (when present) override
        them.  An AX attribute sent with several values makes ``getSingle``
        raise ``ax.AXError``.  Such an attribute is skipped instead of
        aborting the login.

        :param success_response: the response returned by ``Consumer.complete``
        :returns: dict mapping attribute alias -> value
        """
        attrs = {}
        all_attrs = self._REQUIRED_ATTRIBUTES + self._OPTIONAL_ATTRIBUTES

        sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_response)
        if sreg_resp:
            for attr in all_attrs:
                value = sreg_resp.get(attr)
                if value is not None:
                    attrs[attr] = value

        ax_resp = ax.FetchResponse.fromSuccessResponse(success_response)
        if ax_resp:
            for attr in all_attrs:
                try:
                    value = ax_resp.getSingle(utils.SREG2AX[attr])
                except ax.AXError:
                    _logger.warning('Ignoring multi-valued AX attribute %r', attr)
                    continue
                if value is not None:
                    attrs[attr] = value
        return attrs

    def _get_realm(self, req):
        """Return the OpenID realm: the host URL of the current request."""
        return req.httprequest.host_url

    @openerpweb.httprequest
    def verify_direct(self, req, db, url):
        """Plain-HTTP variant of ``verify``.

        Answers 400 when discovery fails.  Otherwise it either redirects the
        browser to the provider or returns the HTML form markup produced by
        python-openid.
        """
        result = self._verify(req, db, url)
        if 'error' in result:
            # NOTE(review): the message can embed user-supplied text; confirm
            # that the werkzeug version in use HTML-escapes descriptions.
            return werkzeug.exceptions.BadRequest(result['error'])
        if result['action'] == 'redirect':
            return werkzeug.utils.redirect(result['value'])
        return result['value']

    @openerpweb.jsonrequest
    def verify(self, req, db, url):
        """JSON-RPC variant: return the raw ``_verify`` result to the client."""
        return self._verify(req, db, url)

    def _verify(self, req, db, url):
        """Begin OpenID authentication of identifier ``url`` for database ``db``.

        Stores the pending OpenID state in ``req.session.openid_session``.

        :returns: ``{'error': msg, 'title': 'OpenID Error'}`` on failure, else
                  ``{'action': 'redirect' | 'post', 'value': url_or_html,
                  'session_id': ...}``
        """
        # The web session id travels in return_to so ``process`` gets the
        # same session back.
        redirect_to = werkzeug.urls.Href(req.httprequest.host_url + 'auth_openid/login/process')(session_id=req.session_id)
        realm = self._get_realm(req)

        session = dict(dbname=db, openid_url=url)       # TODO add origin page ?
        oidconsumer = consumer.Consumer(session, self._store)

        try:
            request = oidconsumer.begin(url)
        except consumer.DiscoveryFailure as exc:
            fetch_error_string = 'Error in discovery: %s' % (str(exc),)
            return {'error': fetch_error_string, 'title': 'OpenID Error'}

        if request is None:
            return {'error': 'No OpenID services found', 'title': 'OpenID Error'}

        req.session.openid_session = session
        self._add_extensions(request)

        if request.shouldSendRedirect():
            redirect_url = request.redirectURL(realm, redirect_to)
            return {'action': 'redirect', 'value': redirect_url, 'session_id': req.session_id}
        else:
            form_html = request.htmlMarkup(realm, redirect_to)
            return {'action': 'post', 'value': form_html, 'session_id': req.session_id}

    @openerpweb.httprequest
    def process(self, req, **kw):
        """Handle the provider's answer (the ``return_to`` URL built by ``_verify``).

        On success with exactly one matching active user, a fresh
        ``openid_key`` is generated and the user is logged in.  In every
        other case, a status and message are stored in the OpenID session
        (read back by ``status``) and the browser is redirected to the login
        page with ``loginerror=1``.
        """
        session = getattr(req.session, 'openid_session', None)
        if not session:
            return set_cookie_and_redirect(req, '/')

        oidconsumer = consumer.Consumer(session, self._store, consumer_class=GoogleAppsAwareConsumer)

        # Providers may send the assertion back in the query string (GET
        # redirect) or as a POSTed form; ``values`` merges both sources.
        query = req.httprequest.values
        info = oidconsumer.complete(query, req.httprequest.base_url)
        display_identifier = info.getDisplayIdentifier()

        session['status'] = info.status

        if info.status == consumer.SUCCESS:
            dbname = session['dbname']
            registry = RegistryManager.get(dbname)
            with registry.cursor() as cr:
                Modules = registry.get('ir.module.module')

                installed = Modules.search_count(cr, 1, ['&', ('name', '=', 'auth_openid'), ('state', '=', 'installed')]) == 1
                if installed:

                    Users = registry.get('res.users')

                    # Users are matched on the identifier typed at login (saved
                    # by ``_verify``), not on the endpoint's canonical ID.
                    openid_url = session['openid_url']

                    attrs = self._get_attributes_from_success_response(info)
                    attrs['openid_url'] = openid_url
                    session['attributes'] = attrs
                    openid_email = attrs.get('email', False)

                    # Accept users whose openid_email is unset, or equal to the
                    # email sent by the provider (only unset when none was sent).
                    domain = []
                    if openid_email:
                        domain += ['|', ('openid_email', '=', False)]
                    domain += [('openid_email', '=', openid_email)]

                    domain += [('openid_url', '=', openid_url), ('active', '=', True)]

                    ids = Users.search(cr, 1, domain)
                    if len(ids) > 1:
                        # Ambiguous identity: never pick one user arbitrarily.
                        # Explicit check, since an ``assert`` vanishes under -O.
                        _logger.warning('OpenID identifier %s matches several active users: %s',
                                        openid_url, ids)
                        session['message'] = 'This OpenID identifier is associated to several active users'
                        return set_cookie_and_redirect(req, '/#action=login&loginerror=1')
                    if ids:
                        user_id = ids[0]
                        login = Users.browse(cr, 1, user_id).login
                        # One-time key stored on the user and passed as the
                        # password to login_and_redirect (presumably checked by
                        # auth_openid's res.users override -- verify).
                        key = randomString(utils.KEY_LENGTH, '0123456789abcdef')
                        Users.write(cr, 1, [user_id], {'openid_key': key})
                        # TODO fill empty fields with the ones from sreg/ax
                        # Commit before logging in, presumably so the new key is
                        # visible to the login's own transaction.
                        cr.commit()

                        return login_and_redirect(req, dbname, login, key)

            session['message'] = 'This OpenID identifier is not associated to any active users'

        elif info.status == consumer.SETUP_NEEDED:
            session['message'] = info.setup_url
        elif info.status == consumer.FAILURE and display_identifier:
            fmt = "Verification of %s failed: %s"
            session['message'] = fmt % (display_identifier, info.message)
        else:   # FAILURE
            # Either we don't understand the code or there is no
            # openid_url included with the error. Give a generic
            # failure message. The library should supply debug
            # information in a log.
            session['message'] = 'Verification failed.'

        return set_cookie_and_redirect(req, '/#action=login&loginerror=1')

    @openerpweb.jsonrequest
    def status(self, req):
        """Return ``{'status', 'message'}`` of the last OpenID attempt (``None`` if unknown)."""
        # ``or {}`` also covers a session attribute that exists but is None.
        session = getattr(req.session, 'openid_session', None) or {}
        return {'status': session.get('status'), 'message': session.get('message')}
231
232
233 # vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4: