Commit 1909472a by Benjamin Dauvergne

authentication: allow all transport modes of access token in OAuth2Authentication

RFC6750 describe three transport modes for access tokens when accessing a
protected resource:
- Auhthorization header with the Bearer authentication type
- form-encoded body parameter
- URI query parameter

This patch add support for last two transport modes.
parent c4459167
...@@ -6,6 +6,7 @@ import base64 ...@@ -6,6 +6,7 @@ import base64
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
...@@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication): ...@@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
OAuth 2 authentication backend using `django-oauth2-provider` OAuth 2 authentication backend using `django-oauth2-provider`
""" """
www_authenticate_realm = 'api' www_authenticate_realm = 'api'
allow_query_params_token = settings.DEBUG
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(OAuth2Authentication, self).__init__(*args, **kwargs) super(OAuth2Authentication, self).__init__(*args, **kwargs)
...@@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication): ...@@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split() auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'bearer': if auth and auth[0].lower() == b'bearer':
access_token = auth[1]
elif 'access_token' in request.POST:
access_token = request.POST['access_token']
elif 'access_token' in request.GET and self.allow_query_params_token:
access_token = request.GET['access_token']
else:
return None return None
if len(auth) == 1: if len(auth) == 1:
...@@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication): ...@@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
msg = 'Invalid bearer header. Token string should not contain spaces.' msg = 'Invalid bearer header. Token string should not contain spaces.'
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(request, auth[1]) return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token): def authenticate_credentials(self, request, access_token):
""" """
......
...@@ -3,6 +3,7 @@ from django.contrib.auth.models import User ...@@ -3,6 +3,7 @@ from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.utils.http import urlencode
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework import permissions from rest_framework import permissions
...@@ -53,10 +54,14 @@ urlpatterns = patterns('', ...@@ -53,10 +54,14 @@ urlpatterns = patterns('',
permission_classes=[permissions.TokenHasReadWriteScope])) permission_classes=[permissions.TokenHasReadWriteScope]))
) )
class OAuth2AuthenticationDebug(OAuth2Authentication):
allow_query_params_token = True
if oauth2_provider is not None: if oauth2_provider is not None:
urlpatterns += patterns('', urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])), permission_classes=[permissions.TokenHasReadWriteScope])),
) )
...@@ -546,6 +551,27 @@ class OAuth2Tests(TestCase): ...@@ -546,6 +551,27 @@ class OAuth2Tests(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post('/oauth2-test/',
data={'access_token': self.access_token.token})
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_failing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test/?%s' % query)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self): def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header() auth = self._create_authorization_header()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment