Commit 2cbe6fde by Will Daly

Merge pull request #7371 from edx/will/django-rest-framework-cors-csrf

Skip CSRF referer check for cross-domain requests.
parents 8324fd02 b625e8e3
"""Django Rest Framework Authentication classes for cross-domain end-points."""
from rest_framework import authentication
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication):
"""Session authentication that skips the referer check over secure connections.
Django Rest Framework's `SessionAuthentication` class calls Django's
CSRF middleware implementation directly, which bypasses the middleware
stack.
This version of `SessionAuthentication` performs the same workaround
as `CorsCSRFMiddleware` to skip the referer check for whitelisted
domains over a secure connection. See `cors_csrf.middleware` for
more information.
Since this subclass overrides only the `enforce_csrf()` method,
it can be mixed in with other `SessionAuthentication` subclasses.
"""
def enforce_csrf(self, request):
"""Skip the referer check if the cross-domain request is allowed. """
if is_cross_domain_request_allowed(request):
with skip_cross_domain_referer_check(request):
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
else:
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
"""Helper methods for CORS and CSRF checks. """
import logging
import urlparse
import contextlib
from django.conf import settings
log = logging.getLogger(__name__)
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https':
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
@contextlib.contextmanager
def skip_cross_domain_referer_check(request):
"""Skip the cross-domain CSRF referer check.
Django's CSRF middleware performs the referer check
only when the request is made over a secure connection.
To skip the check, we patch `request.is_secure()` to
False.
"""
is_secure_default = request.is_secure
request.is_secure = lambda: False
try:
yield
finally:
request.is_secure = is_secure_default
...@@ -43,80 +43,14 @@ CSRF cookie. ...@@ -43,80 +43,14 @@ CSRF cookie.
""" """
import logging import logging
import urlparse
from django.conf import settings from django.conf import settings
from django.middleware.csrf import CsrfViewMiddleware from django.middleware.csrf import CsrfViewMiddleware
from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured
log = logging.getLogger(__name__) from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https': log = logging.getLogger(__name__)
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
class CorsCSRFMiddleware(CsrfViewMiddleware): class CorsCSRFMiddleware(CsrfViewMiddleware):
...@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware): ...@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware):
log.debug("Could not disable CSRF middleware referer check for cross-domain request.") log.debug("Could not disable CSRF middleware referer check for cross-domain request.")
return return
is_secure_default = request.is_secure with skip_cross_domain_referer_check(request):
return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
def is_secure_patched():
"""
Avoid triggering the additional CSRF middleware checks on the referrer
"""
return False
request.is_secure = is_secure_patched
res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
request.is_secure = is_secure_default
return res
class CsrfCrossDomainCookieMiddleware(object): class CsrfCrossDomainCookieMiddleware(object):
......
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication."""
from mock import patch
from django.test import TestCase
from django.test.utils import override_settings
from django.test.client import RequestFactory
from django.conf import settings
from rest_framework.exceptions import AuthenticationFailed
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
class CrossDomainAuthTest(TestCase):
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """
URL = "/dummy_url"
REFERER = "https://www.edx.org"
CSRF_TOKEN = 'abcd1234'
def setUp(self):
super(CrossDomainAuthTest, self).setUp()
self.auth = SessionAuthenticationCrossDomainCsrf()
def test_perform_csrf_referer_check(self):
request = self._fake_request()
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
self.auth.enforce_csrf(request)
@patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
@override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
def test_skip_csrf_referer_check(self):
request = self._fake_request()
result = self.auth.enforce_csrf(request)
self.assertIs(result, None)
self.assertTrue(request.is_secure())
def _fake_request(self):
"""Construct a fake request with a referer and CSRF token over a secure connection. """
factory = RequestFactory()
factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN
request = factory.post(
self.URL,
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=self.CSRF_TOKEN
)
request.is_secure = lambda: True
return request
...@@ -6,6 +6,8 @@ import json ...@@ -6,6 +6,8 @@ import json
import unittest import unittest
from mock import patch from mock import patch
from django.test import Client
from django.core.handlers.wsgi import WSGIRequest
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from rest_framework import status from rest_framework import status
...@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
url = reverse('courseenrollments') url = reverse('courseenrollments')
resp = self.client.get(url) resp = self.client.get(url)
return json.loads(resp.content) return json.loads(resp.content)
def cross_domain_config(func):
"""Decorator for configuring a cross-domain request. """
feature_flag_decorator = patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
settings_decorator = override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
return feature_flag_decorator(
settings_decorator(
is_secure_decorator(func)
)
)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EnrollmentCrossDomainTest(ModuleStoreTestCase):
"""Test cross-domain calls to the enrollment end-points. """
USERNAME = "Bob"
EMAIL = "bob@example.com"
PASSWORD = "edx"
REFERER = "https://www.edx.org"
def setUp(self):
""" Create a course and user, then log in. """
super(EnrollmentCrossDomainTest, self).setUp()
self.course = CourseFactory.create()
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.client = Client(enforce_csrf_checks=True)
self.client.login(username=self.USERNAME, password=self.PASSWORD)
@cross_domain_config
def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument
csrf_cookie = self._get_csrf_cookie()
resp = self._cross_domain_post(csrf_cookie)
# Expect that the request gets through successfully,
# passing the CSRF checks (including the referer check).
self.assertEqual(resp.status_code, 200)
@cross_domain_config
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
resp = self._cross_domain_post('invalid_csrf_token')
self.assertEqual(resp.status_code, 401)
def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """
url = reverse('courseenrollment', kwargs={
'course_id': unicode(self.course.id)
})
resp = self.client.get(url, HTTP_REFERER=self.REFERER)
self.assertEqual(resp.status_code, 200)
self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member
return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member
def _cross_domain_post(self, csrf_cookie):
"""Perform a cross-domain POST request. """
url = reverse('courseenrollments')
params = json.dumps({
'course_details': {
'course_id': unicode(self.course.id),
},
'user': self.user.username
})
return self.client.post(
url, params, content_type='application/json',
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=csrf_cookie
)
...@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle ...@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from embargo import api as embargo_api from embargo import api as embargo_api
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
from cors_csrf.decorators import ensure_csrf_cookie_cross_domain from cors_csrf.decorators import ensure_csrf_cookie_cross_domain
from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser
from util.disable_rate_limit import can_disable_rate_limit from util.disable_rate_limit import can_disable_rate_limit
...@@ -24,6 +25,11 @@ from enrollment.errors import ( ...@@ -24,6 +25,11 @@ from enrollment.errors import (
) )
class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf):
"""Session authentication that allows inactive users and cross-domain requests. """
pass
class EnrollmentUserThrottle(UserRateThrottle): class EnrollmentUserThrottle(UserRateThrottle):
"""Limit the number of requests users can make to the enrollment API.""" """Limit the number of requests users can make to the enrollment API."""
# TODO Limit significantly after performance testing. # pylint: disable=fixme # TODO Limit significantly after performance testing. # pylint: disable=fixme
...@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): ...@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
* user: The ID of the user. * user: The ID of the user.
""" """
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth
permission_classes = ApiKeyHeaderPermissionIsAuthenticated, permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
throttle_classes = EnrollmentUserThrottle, throttle_classes = EnrollmentUserThrottle,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment