Commit b625e8e3 by Will Daly

Skip CSRF referer check for cross-domain requests.

This commit extends the workaround in `cors_csrf` middleware
to Django Rest Framework's SessionAuthentication, which
calls Django's CSRF middleware directly.

The workaround checks the cross domain whitelist and
skips the CSRF referer check for domains on the whitelist.
parent 28ad1388
"""Django Rest Framework Authentication classes for cross-domain end-points."""
from rest_framework import authentication
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
class SessionAuthenticationCrossDomainCsrf(authentication.SessionAuthentication):
"""Session authentication that skips the referer check over secure connections.
Django Rest Framework's `SessionAuthentication` class calls Django's
CSRF middleware implementation directly, which bypasses the middleware
stack.
This version of `SessionAuthentication` performs the same workaround
as `CorsCSRFMiddleware` to skip the referer check for whitelisted
domains over a secure connection. See `cors_csrf.middleware` for
more information.
Since this subclass overrides only the `enforce_csrf()` method,
it can be mixed in with other `SessionAuthentication` subclasses.
"""
def enforce_csrf(self, request):
"""Skip the referer check if the cross-domain request is allowed. """
if is_cross_domain_request_allowed(request):
with skip_cross_domain_referer_check(request):
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
else:
return super(SessionAuthenticationCrossDomainCsrf, self).enforce_csrf(request)
"""Helper methods for CORS and CSRF checks. """
import logging
import urlparse
import contextlib
from django.conf import settings
log = logging.getLogger(__name__)
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https':
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
@contextlib.contextmanager
def skip_cross_domain_referer_check(request):
"""Skip the cross-domain CSRF referer check.
Django's CSRF middleware performs the referer check
only when the request is made over a secure connection.
To skip the check, we patch `request.is_secure()` to
False.
"""
is_secure_default = request.is_secure
request.is_secure = lambda: False
try:
yield
finally:
request.is_secure = is_secure_default
......@@ -43,80 +43,14 @@ CSRF cookie.
"""
import logging
import urlparse
from django.conf import settings
from django.middleware.csrf import CsrfViewMiddleware
from django.core.exceptions import MiddlewareNotUsed, ImproperlyConfigured
log = logging.getLogger(__name__)
def is_cross_domain_request_allowed(request):
"""Check whether we should allow the cross-domain request.
We allow a cross-domain request only if:
1) The request is made securely and the referer has "https://" as the protocol.
2) The referer domain has been whitelisted.
from cors_csrf.helpers import is_cross_domain_request_allowed, skip_cross_domain_referer_check
Arguments:
request (HttpRequest)
Returns:
bool
"""
referer = request.META.get('HTTP_REFERER')
referer_parts = urlparse.urlparse(referer) if referer else None
referer_hostname = referer_parts.hostname if referer_parts is not None else None
# Use CORS_ALLOW_INSECURE *only* for development and testing environments;
# it should never be enabled in production.
if not getattr(settings, 'CORS_ALLOW_INSECURE', False):
if not request.is_secure():
log.debug(
u"Request is not secure, so we cannot send the CSRF token. "
u"For testing purposes, you can disable this check by setting "
u"`CORS_ALLOW_INSECURE` to True in the settings"
)
return False
if not referer:
log.debug(u"No referer provided over a secure connection, so we cannot check the protocol.")
return False
if not referer_parts.scheme == 'https':
log.debug(u"Referer '%s' must have the scheme 'https'")
return False
domain_is_whitelisted = (
getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False) or
referer_hostname in getattr(settings, 'CORS_ORIGIN_WHITELIST', [])
)
if not domain_is_whitelisted:
if referer_hostname is None:
# If no referer is specified, we can't check if it's a cross-domain
# request or not.
log.debug(u"Referrer hostname is `None`, so it is not on the whitelist.")
elif referer_hostname != request.get_host():
log.warning(
(
u"Domain '%s' is not on the cross domain whitelist. "
u"Add the domain to `CORS_ORIGIN_WHITELIST` or set "
u"`CORS_ORIGIN_ALLOW_ALL` to True in the settings."
), referer_hostname
)
else:
log.debug(
(
u"Domain '%s' is the same as the hostname in the request, "
u"so we are not going to treat it as a cross-domain request."
), referer_hostname
)
return False
return True
log = logging.getLogger(__name__)
class CorsCSRFMiddleware(CsrfViewMiddleware):
......@@ -134,18 +68,8 @@ class CorsCSRFMiddleware(CsrfViewMiddleware):
log.debug("Could not disable CSRF middleware referer check for cross-domain request.")
return
is_secure_default = request.is_secure
def is_secure_patched():
"""
Avoid triggering the additional CSRF middleware checks on the referrer
"""
return False
request.is_secure = is_secure_patched
res = super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
request.is_secure = is_secure_default
return res
with skip_cross_domain_referer_check(request):
return super(CorsCSRFMiddleware, self).process_view(request, callback, callback_args, callback_kwargs)
class CsrfCrossDomainCookieMiddleware(object):
......
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication."""
from mock import patch
from django.test import TestCase
from django.test.utils import override_settings
from django.test.client import RequestFactory
from django.conf import settings
from rest_framework.exceptions import AuthenticationFailed
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
class CrossDomainAuthTest(TestCase):
"""Tests for the CORS CSRF version of Django Rest Framework's SessionAuthentication. """
URL = "/dummy_url"
REFERER = "https://www.edx.org"
CSRF_TOKEN = 'abcd1234'
def setUp(self):
super(CrossDomainAuthTest, self).setUp()
self.auth = SessionAuthenticationCrossDomainCsrf()
def test_perform_csrf_referer_check(self):
request = self._fake_request()
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
self.auth.enforce_csrf(request)
@patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
@override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
def test_skip_csrf_referer_check(self):
request = self._fake_request()
result = self.auth.enforce_csrf(request)
self.assertIs(result, None)
self.assertTrue(request.is_secure())
def _fake_request(self):
"""Construct a fake request with a referer and CSRF token over a secure connection. """
factory = RequestFactory()
factory.cookies[settings.CSRF_COOKIE_NAME] = self.CSRF_TOKEN
request = factory.post(
self.URL,
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=self.CSRF_TOKEN
)
request.is_secure = lambda: True
return request
......@@ -6,6 +6,8 @@ import json
import unittest
from mock import patch
from django.test import Client
from django.core.handlers.wsgi import WSGIRequest
from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase
from rest_framework import status
......@@ -365,3 +367,81 @@ class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
url = reverse('courseenrollments')
resp = self.client.get(url)
return json.loads(resp.content)
def cross_domain_config(func):
"""Decorator for configuring a cross-domain request. """
feature_flag_decorator = patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
settings_decorator = override_settings(
CORS_ORIGIN_WHITELIST=["www.edx.org"],
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
return feature_flag_decorator(
settings_decorator(
is_secure_decorator(func)
)
)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EnrollmentCrossDomainTest(ModuleStoreTestCase):
"""Test cross-domain calls to the enrollment end-points. """
USERNAME = "Bob"
EMAIL = "bob@example.com"
PASSWORD = "edx"
REFERER = "https://www.edx.org"
def setUp(self):
""" Create a course and user, then log in. """
super(EnrollmentCrossDomainTest, self).setUp()
self.course = CourseFactory.create()
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.client = Client(enforce_csrf_checks=True)
self.client.login(username=self.USERNAME, password=self.PASSWORD)
@cross_domain_config
def test_cross_domain_change_enrollment(self, *args): # pylint: disable=unused-argument
csrf_cookie = self._get_csrf_cookie()
resp = self._cross_domain_post(csrf_cookie)
# Expect that the request gets through successfully,
# passing the CSRF checks (including the referer check).
self.assertEqual(resp.status_code, 200)
@cross_domain_config
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
resp = self._cross_domain_post('invalid_csrf_token')
self.assertEqual(resp.status_code, 401)
def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """
url = reverse('courseenrollment', kwargs={
'course_id': unicode(self.course.id)
})
resp = self.client.get(url, HTTP_REFERER=self.REFERER)
self.assertEqual(resp.status_code, 200)
self.assertIn('prod-edx-csrftoken', resp.cookies) # pylint: disable=no-member
return resp.cookies['prod-edx-csrftoken'].value # pylint: disable=no-member
def _cross_domain_post(self, csrf_cookie):
"""Perform a cross-domain POST request. """
url = reverse('courseenrollments')
params = json.dumps({
'course_details': {
'course_id': unicode(self.course.id),
},
'user': self.user.username
})
return self.client.post(
url, params, content_type='application/json',
HTTP_REFERER=self.REFERER,
HTTP_X_CSRFTOKEN=csrf_cookie
)
......@@ -14,6 +14,7 @@ from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
from opaque_keys.edx.keys import CourseKey
from embargo import api as embargo_api
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
from cors_csrf.decorators import ensure_csrf_cookie_cross_domain
from util.authentication import SessionAuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser
from util.disable_rate_limit import can_disable_rate_limit
......@@ -24,6 +25,11 @@ from enrollment.errors import (
)
class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf):
"""Session authentication that allows inactive users and cross-domain requests. """
pass
class EnrollmentUserThrottle(UserRateThrottle):
"""Limit the number of requests users can make to the enrollment API."""
# TODO Limit significantly after performance testing. # pylint: disable=fixme
......@@ -267,7 +273,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
* user: The ID of the user.
"""
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
authentication_classes = OAuth2AuthenticationAllowInactiveUser, EnrollmentCrossDomainSessionAuth
permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
throttle_classes = EnrollmentUserThrottle,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment