Commit 873a142a by Tom Christie

Implementing 401 vs 403 responses

parent 957700ec
...@@ -34,27 +34,33 @@ class BasicAuthentication(BaseAuthentication): ...@@ -34,27 +34,33 @@ class BasicAuthentication(BaseAuthentication):
""" """
HTTP Basic authentication against username/password. HTTP Basic authentication against username/password.
""" """
www_authenticate_realm = 'api'
def authenticate(self, request): def authenticate(self, request):
""" """
Returns a `User` if a correct username and password have been supplied Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`. using HTTP Basic authentication. Otherwise returns `None`.
""" """
if 'HTTP_AUTHORIZATION' in request.META: auth = request.META.get('HTTP_AUTHORIZATION', '').split()
auth = request.META['HTTP_AUTHORIZATION'].split()
if len(auth) == 2 and auth[0].lower() == "basic": if not auth or auth[0].lower() != "basic":
try: return None
auth_parts = base64.b64decode(auth[1]).partition(':')
except TypeError: if len(auth) != 2:
return None raise exceptions.AuthenticationFailed('Invalid basic header')
try: try:
userid = smart_unicode(auth_parts[0]) auth_parts = base64.b64decode(auth[1]).partition(':')
password = smart_unicode(auth_parts[2]) except TypeError:
except DjangoUnicodeDecodeError: raise exceptions.AuthenticationFailed('Invalid basic header')
return None
try:
return self.authenticate_credentials(userid, password) userid = smart_unicode(auth_parts[0])
password = smart_unicode(auth_parts[2])
except DjangoUnicodeDecodeError:
raise exceptions.AuthenticationFailed('Invalid basic header')
return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password): def authenticate_credentials(self, userid, password):
""" """
...@@ -63,6 +69,10 @@ class BasicAuthentication(BaseAuthentication): ...@@ -63,6 +69,10 @@ class BasicAuthentication(BaseAuthentication):
user = authenticate(username=userid, password=password) user = authenticate(username=userid, password=password)
if user is not None and user.is_active: if user is not None and user.is_active:
return (user, None) return (user, None)
raise exceptions.AuthenticationFailed('Invalid username/password')
def authenticate_header(self):
return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication): class SessionAuthentication(BaseAuthentication):
...@@ -82,7 +92,7 @@ class SessionAuthentication(BaseAuthentication): ...@@ -82,7 +92,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required # Unauthenticated, CSRF validation not required
if not user or not user.is_active: if not user or not user.is_active:
return return None
# Enforce CSRF validation for session based authentication. # Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware): class CSRFCheck(CsrfViewMiddleware):
...@@ -93,7 +103,7 @@ class SessionAuthentication(BaseAuthentication): ...@@ -93,7 +103,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {}) reason = CSRFCheck().process_view(http_request, None, (), {})
if reason: if reason:
# CSRF failed, bail with explicit error message # CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user # CSRF passed with authenticated user
return (user, None) return (user, None)
...@@ -120,14 +130,26 @@ class TokenAuthentication(BaseAuthentication): ...@@ -120,14 +130,26 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split() auth = request.META.get('HTTP_AUTHORIZATION', '').split()
if len(auth) == 2 and auth[0].lower() == "token": if not auth or auth[0].lower() != "token":
key = auth[1] return None
try:
token = self.model.objects.get(key=key) if len(auth) != 2:
except self.model.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token header')
return None
return self.authenticate_credentials(auth[1])
def authenticate_credentials(self, key):
try:
token = self.model.objects.get(key=key)
except self.model.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
if token.user.is_active:
return (token.user, token)
raise exceptions.AuthenticationFailed('User inactive or deleted')
def authenticate_header(self):
return 'Token'
if token.user.is_active:
return (token.user, token)
# TODO: OAuthAuthentication # TODO: OAuthAuthentication
...@@ -86,6 +86,7 @@ class Request(object): ...@@ -86,6 +86,7 @@ class Request(object):
self._method = Empty self._method = Empty
self._content_type = Empty self._content_type = Empty
self._stream = Empty self._stream = Empty
self._authenticator = None
if self.parser_context is None: if self.parser_context is None:
self.parser_context = {} self.parser_context = {}
...@@ -166,7 +167,7 @@ class Request(object): ...@@ -166,7 +167,7 @@ class Request(object):
by the authentication classes provided to the request. by the authentication classes provided to the request.
""" """
if not hasattr(self, '_user'): if not hasattr(self, '_user'):
self._user, self._auth = self._authenticate() self._authenticator, self._user, self._auth = self._authenticate()
return self._user return self._user
@property @property
...@@ -176,9 +177,17 @@ class Request(object): ...@@ -176,9 +177,17 @@ class Request(object):
request, such as an authentication token. request, such as an authentication token.
""" """
if not hasattr(self, '_auth'): if not hasattr(self, '_auth'):
self._user, self._auth = self._authenticate() self._authenticator, self._user, self._auth = self._authenticate()
return self._auth return self._auth
@property
def successful_authenticator(self):
"""
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
return self._authenticator
def _load_data_and_files(self): def _load_data_and_files(self):
""" """
Parses the request content into self.DATA and self.FILES. Parses the request content into self.DATA and self.FILES.
...@@ -282,21 +291,23 @@ class Request(object): ...@@ -282,21 +291,23 @@ class Request(object):
def _authenticate(self): def _authenticate(self):
""" """
Attempt to authenticate the request using each authentication instance in turn. Attempt to authenticate the request using each authentication instance
Returns a two-tuple of (user, authtoken). in turn.
Returns a three-tuple of (authenticator, user, authtoken).
""" """
for authenticator in self.authenticators: for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self) user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None: if not user_auth_tuple is None:
return user_auth_tuple user, auth = user_auth_tuple
return (authenticator, user, auth)
return self._not_authenticated() return self._not_authenticated()
def _not_authenticated(self): def _not_authenticated(self):
""" """
Return a two-tuple of (user, authtoken), representing an Return a three-tuple of (authenticator, user, authtoken), representing
unauthenticated request. an unauthenticated request.
By default this will be (AnonymousUser, None). By default this will be (None, AnonymousUser, None).
""" """
if api_settings.UNAUTHENTICATED_USER: if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER() user = api_settings.UNAUTHENTICATED_USER()
...@@ -308,7 +319,7 @@ class Request(object): ...@@ -308,7 +319,7 @@ class Request(object):
else: else:
auth = None auth = None
return (user, auth) return (None, user, auth)
def __getattr__(self, attr): def __getattr__(self, attr):
""" """
......
...@@ -148,6 +148,8 @@ class APIView(View): ...@@ -148,6 +148,8 @@ class APIView(View):
""" """
If request is not permitted, determine what kind of exception to raise. If request is not permitted, determine what kind of exception to raise.
""" """
if self.request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied()
def throttled(self, request, wait): def throttled(self, request, wait):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment