Commit 2cbe6fde by Will Daly

Merge pull request #7371 from edx/will/django-rest-framework-cors-csrf

Skip CSRF referer check for cross-domain requests.
parents 8324fd02 b625e8e3
"""Django Rest Framework Authentication classes for cross-domain end-points."""
from rest_framework import authentication
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication):
"""Session authentication that skips the referer check over secure connections.
Django Rest Framework's `SessionAuthentication` class calls Django's
CSRF middleware implementation directly, which bypasses the middleware
stack.
This version of `SessionAuthentication` performs the same workaround
as `CorsCSRFMiddleware` to skip the referer check for whitelisted
domains over a secure connection. See `cors_csrf.middleware` for
more information.
Since this subclass overrides only the `enforce_csrf()` method,
it can be mixed in with other `SessionAuthentication` subclasses.
"""
def enforce_csrf(self, request):
"""Skip the referer check if the cross-domain request is allowed. """
if is_cross_domain_request_allowed(request):
with skip_cross_domain_referer_check(request):
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
else:
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
"""Helper methods for CORS and CSRF checks. """
import logging
import urlparse
import contextlib
from django.conf import settings
log = logging.getLogger(__name__)
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https':
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
@contextlib.contextmanager
def skip_cross_domain_referer_check(request):
"""Skip the cross-domain CSRF referer check.
Django's CSRF middleware performs the referer check
only when the request is made over a secure connection.
To skip the check, we patch `request.is_secure()` to
False.
"""
is_secure_default = request.is_secure
request.is_secure = lambda: False
try:
yield
finally:
request.is_secure = is_secure_default
......@@ -43,80 +43,14 @@ CSRF cookie.
"""
import logging
import urlparse
from django.conf import settings
from django.middleware.csrf import CsrfViewMiddleware
from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured
log = logging.getLogger(__name__)
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https':
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
log = logging.getLogger(__name__)
class CorsCSRFMiddleware(CsrfViewMiddleware):
......@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware):
log.debug("Could not disable CSRF middleware referer check for cross-domain request.")
return
is_secure_default = request.is_secure
def is_secure_patched():
"""
Avoid triggering the additional CSRF middleware checks on the referrer
"""
return False
request.is_secure = is_secure_patched
res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
request.is_secure = is_secure_default
return res
with skip_cross_domain_referer_check(request):
return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
class CsrfCrossDomainCookieMiddleware(object):
......
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication."""
from mock import patch
from django.test import TestCase
from django.test.utils import override_settings
from django.test.client import RequestFactory
from django.conf import settings
from rest_framework.exceptions import AuthenticationFailed
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
class CrossDomainAuthTest(TestCase):
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """
URL = "/dummy_url"
REFERER = "https://www.edx.org"
CSRF_TOKEN = 'abcd1234'
def setUp(self):
super(CrossDomainAuthTest, self).setUp()
self.auth = SessionAuthenticationCrossDomainCsrf()
def test_perform_csrf_referer_check(self):
request = self._fake_request()
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
self.auth.enforce_csrf(request)
@patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
@override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
def test_skip_csrf_referer_check(self):
request = self._fake_request()
result = self.auth.enforce_csrf(request)
self.assertIs(result, None)
self.assertTrue(request.is_secure())
def _fake_request(self):
"""Construct a fake request with a referer and CSRF token over a secure connection. """
factory = RequestFactory()
factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN
request = factory.post(
self.URL,
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=self.CSRF_TOKEN
)
request.is_secure = lambda: True
return request
......@@ -6,6 +6,8 @@ import json
import unittest
from mock import patch
from django.test import Client
from django.core.handlers.wsgi import WSGIRequest
from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase
from rest_framework import status
......@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
url = reverse('courseenrollments')
resp = self.client.get(url)
return json.loads(resp.content)
def cross_domain_config(func):
"""Decorator for configuring a cross-domain request. """
feature_flag_decorator = patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
settings_decorator = override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
return feature_flag_decorator(
settings_decorator(
is_secure_decorator(func)
)
)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EnrollmentCrossDomainTest(ModuleStoreTestCase):
"""Test cross-domain calls to the enrollment end-points. """
USERNAME = "Bob"
EMAIL = "bob@example.com"
PASSWORD = "edx"
REFERER = "https://www.edx.org"
def setUp(self):
""" Create a course and user, then log in. """
super(EnrollmentCrossDomainTest, self).setUp()
self.course = CourseFactory.create()
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.client = Client(enforce_csrf_checks=True)
self.client.login(username=self.USERNAME, password=self.PASSWORD)
@cross_domain_config
def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument
csrf_cookie = self._get_csrf_cookie()
resp = self._cross_domain_post(csrf_cookie)
# Expect that the request gets through successfully,
# passing the CSRF checks (including the referer check).
self.assertEqual(resp.status_code, 200)
@cross_domain_config
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
resp = self._cross_domain_post('invalid_csrf_token')
self.assertEqual(resp.status_code, 401)
def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """
url = reverse('courseenrollment', kwargs={
'course_id': unicode(self.course.id)
})
resp = self.client.get(url, HTTP_REFERER=self.REFERER)
self.assertEqual(resp.status_code, 200)
self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member
return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member
def _cross_domain_post(self, csrf_cookie):
"""Perform a cross-domain POST request. """
url = reverse('courseenrollments')
params = json.dumps({
'course_details': {
'course_id': unicode(self.course.id),
},
'user': self.user.username
})
return self.client.post(
url, params, content_type='application/json',
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=csrf_cookie
)
......@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
from opaque_keys.edx.keys import CourseKey
from embargo import api as embargo_api
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
from cors_csrf.decorators import ensure_csrf_cookie_cross_domain
from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser
from util.disable_rate_limit import can_disable_rate_limit
......@@ -24,6 +25,11 @@ from enrollment.errors import (
)
class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf):
"""Session authentication that allows inactive users and cross-domain requests. """
pass
class EnrollmentUserThrottle(UserRateThrottle):
"""Limit the number of requests users can make to the enrollment API."""
# TODO Limit significantly after performance testing. # pylint: disable=fixme
......@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
* user: The ID of the user.
"""
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth
permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
throttle_classes = EnrollmentUserThrottle,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment