Commit 1a346415 by Toby Lawrence

Switch to header_control in contentserver.

parent 4ca2692f
...@@ -311,7 +311,7 @@ simplefilter('ignore') ...@@ -311,7 +311,7 @@ simplefilter('ignore')
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
'request_cache.middleware.RequestCache', 'request_cache.middleware.RequestCache',
'clean_headers.middleware.CleanHeadersMiddleware', 'header_control.middleware.HeaderControlMiddleware',
'django.middleware.cache.UpdateCacheMiddleware', 'django.middleware.cache.UpdateCacheMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
......
...@@ -11,7 +11,7 @@ from django.http import ( ...@@ -11,7 +11,7 @@ from django.http import (
from student.models import CourseEnrollment from student.models import CourseEnrollment
from contentserver.models import CourseAssetCacheTtlConfig from contentserver.models import CourseAssetCacheTtlConfig
from clean_headers import remove_headers_from_response from header_control import force_header_for_response
from xmodule.assetstore.assetmgr import AssetManager from xmodule.assetstore.assetmgr import AssetManager
from xmodule.contentstore.content import StaticContent, XASSET_LOCATION_TAG from xmodule.contentstore.content import StaticContent, XASSET_LOCATION_TAG
from xmodule.modulestore import InvalidLocationError from xmodule.modulestore import InvalidLocationError
...@@ -153,7 +153,10 @@ class StaticContentServer(object): ...@@ -153,7 +153,10 @@ class StaticContentServer(object):
response['Last-Modified'] = content.last_modified_at.strftime(HTTP_DATE_FORMAT) response['Last-Modified'] = content.last_modified_at.strftime(HTTP_DATE_FORMAT)
remove_headers_from_response(response, "Vary") # Force the Vary header to only vary responses on Origin, so that XHR and browser requests get cached
# separately and don't screw over one another. i.e. a browser request that doesn't send Origin, and
# caches a version of the response without CORS headers, in turn breaking XHR requests.
force_header_for_response(response, 'Vary', 'Origin')
@staticmethod @staticmethod
def get_expiration_value(now, cache_ttl): def get_expiration_value(now, cache_ttl):
......
...@@ -16,12 +16,13 @@ from mock import patch ...@@ -16,12 +16,13 @@ from mock import patch
from xmodule.contentstore.django import contentstore from xmodule.contentstore.django import contentstore
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.xml_importer import import_course_from_xml from xmodule.modulestore.xml_importer import import_course_from_xml
from contentserver.middleware import parse_range_header, HTTP_DATE_FORMAT, StaticContentServer from contentserver.middleware import parse_range_header, HTTP_DATE_FORMAT, StaticContentServer
from student.models import CourseEnrollment from student.models import CourseEnrollment
from student.tests.factories import UserFactory, AdminFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -33,39 +34,44 @@ TEST_DATA_DIR = settings.COMMON_TEST_DATA_ROOT ...@@ -33,39 +34,44 @@ TEST_DATA_DIR = settings.COMMON_TEST_DATA_ROOT
@ddt.ddt @ddt.ddt
@override_settings(CONTENTSTORE=TEST_DATA_CONTENTSTORE) @override_settings(CONTENTSTORE=TEST_DATA_CONTENTSTORE)
class ContentStoreToyCourseTest(ModuleStoreTestCase): class ContentStoreToyCourseTest(SharedModuleStoreTestCase):
""" """
Tests that use the toy course. Tests that use the toy course.
""" """
def setUp(self): @classmethod
""" def setUpClass(cls):
Create user and login. super(ContentStoreToyCourseTest, cls).setUpClass()
"""
self.staff_pwd = super(ContentStoreToyCourseTest, self).setUp()
self.staff_usr = self.user
self.non_staff_usr, self.non_staff_pwd = self.create_non_staff_user()
self.client = Client() cls.contentstore = contentstore()
self.contentstore = contentstore() cls.modulestore = modulestore()
store = modulestore()._get_modulestore_by_type(ModuleStoreEnum.Type.mongo) # pylint: disable=protected-access
self.course_key = store.make_course_key('edX', 'toy', '2012_Fall') cls.course_key = cls.modulestore.make_course_key('edX', 'toy', '2012_Fall')
import_course_from_xml( import_course_from_xml(
store, self.user.id, TEST_DATA_DIR, ['toy'], cls.modulestore, 1, TEST_DATA_DIR, ['toy'],
static_content_store=self.contentstore, verbose=True static_content_store=cls.contentstore, verbose=True
) )
# A locked asset # A locked asset
self.locked_asset = self.course_key.make_asset_key('asset', 'sample_static.txt') cls.locked_asset = cls.course_key.make_asset_key('asset', 'sample_static.txt')
self.url_locked = unicode(self.locked_asset) cls.url_locked = unicode(cls.locked_asset)
self.contentstore.set_attr(self.locked_asset, 'locked', True) cls.contentstore.set_attr(cls.locked_asset, 'locked', True)
# An unlocked asset # An unlocked asset
self.unlocked_asset = self.course_key.make_asset_key('asset', 'another_static.txt') cls.unlocked_asset = cls.course_key.make_asset_key('asset', 'another_static.txt')
self.url_unlocked = unicode(self.unlocked_asset) cls.url_unlocked = unicode(cls.unlocked_asset)
self.length_unlocked = self.contentstore.get_attr(self.unlocked_asset, 'length') cls.length_unlocked = cls.contentstore.get_attr(cls.unlocked_asset, 'length')
def setUp(self):
"""
Create user and login.
"""
super(ContentStoreToyCourseTest, self).setUp()
self.staff_usr = AdminFactory.create()
self.non_staff_usr = UserFactory.create()
self.client = Client()
def test_unlocked_asset(self): def test_unlocked_asset(self):
""" """
...@@ -89,7 +95,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -89,7 +95,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
Test that locked assets behave appropriately in case user is logged in Test that locked assets behave appropriately in case user is logged in
in but not registered for the course. in but not registered for the course.
""" """
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked) resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 403) self.assertEqual(resp.status_code, 403)
...@@ -101,7 +107,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -101,7 +107,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key) CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked) resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
...@@ -109,7 +115,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -109,7 +115,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
""" """
Test that locked assets behave appropriately in case user is staff. Test that locked assets behave appropriately in case user is staff.
""" """
self.client.login(username=self.staff_usr, password=self.staff_pwd) self.client.login(username=self.staff_usr, password='test')
resp = self.client.get(self.url_locked) resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
...@@ -191,6 +197,15 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -191,6 +197,15 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
first=(self.length_unlocked), last=(self.length_unlocked))) first=(self.length_unlocked), last=(self.length_unlocked)))
self.assertEqual(resp.status_code, 416) self.assertEqual(resp.status_code, 416)
def test_vary_header_sent(self):
"""
Tests that we're properly setting the Vary header to ensure browser requests don't get
cached in a way that breaks XHR requests to the same asset.
"""
resp = self.client.get(self.url_unlocked)
self.assertEqual(resp.status_code, 200)
self.assertEquals('Origin', resp['Vary'])
@patch('contentserver.models.CourseAssetCacheTtlConfig.get_cache_ttl') @patch('contentserver.models.CourseAssetCacheTtlConfig.get_cache_ttl')
def test_cache_headers_with_ttl_unlocked(self, mock_get_cache_ttl): def test_cache_headers_with_ttl_unlocked(self, mock_get_cache_ttl):
""" """
...@@ -215,7 +230,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -215,7 +230,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key) CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked) resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertNotIn('Expires', resp) self.assertNotIn('Expires', resp)
...@@ -245,7 +260,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -245,7 +260,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key) CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key)) self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd) self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked) resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertNotIn('Expires', resp) self.assertNotIn('Expires', resp)
...@@ -256,20 +271,6 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase): ...@@ -256,20 +271,6 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
near_expire_dt = StaticContentServer.get_expiration_value(start_dt, 55) near_expire_dt = StaticContentServer.get_expiration_value(start_dt, 55)
self.assertEqual("Thu, 01 Dec 1983 20:00:55 GMT", near_expire_dt) self.assertEqual("Thu, 01 Dec 1983 20:00:55 GMT", near_expire_dt)
def test_response_no_vary_header_unlocked(self):
resp = self.client.get(self.url_unlocked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Vary', resp)
def test_response_no_vary_header_locked(self):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Vary', resp)
@ddt.ddt @ddt.ddt
class ParseRangeHeaderTestCase(unittest.TestCase): class ParseRangeHeaderTestCase(unittest.TestCase):
......
...@@ -11,6 +11,7 @@ def remove_headers_from_response(response, *headers): ...@@ -11,6 +11,7 @@ def remove_headers_from_response(response, *headers):
"""Removes the given headers from the response using the header_control middleware.""" """Removes the given headers from the response using the header_control middleware."""
response.remove_headers = headers response.remove_headers = headers
def force_header_for_response(response, header, value): def force_header_for_response(response, header, value):
"""Forces the given header for the given response using the header_control middleware.""" """Forces the given header for the given response using the header_control middleware."""
force_headers = {} force_headers = {}
......
...@@ -5,6 +5,7 @@ Middleware decorator for removing headers. ...@@ -5,6 +5,7 @@ Middleware decorator for removing headers.
from functools import wraps from functools import wraps
from header_control import remove_headers_from_response, force_header_for_response from header_control import remove_headers_from_response, force_header_for_response
def remove_headers(*headers): def remove_headers(*headers):
""" """
Decorator that removes specific headers from the response. Decorator that removes specific headers from the response.
......
...@@ -15,20 +15,10 @@ class HeaderControlMiddleware(object): ...@@ -15,20 +15,10 @@ class HeaderControlMiddleware(object):
Processes the given response, potentially remove or modifying headers. Processes the given response, potentially remove or modifying headers.
""" """
if len(getattr(response, 'remove_headers', [])) > 0: for header in getattr(response, 'remove_headers', []):
for header in response.remove_headers: del response[header]
try:
del response[header]
except KeyError:
pass
if len(getattr(response, 'force_headers', {})) > 0: for header, value in getattr(response, 'force_headers', {}).iteritems():
for header, value in response.force_headers.iteritems(): response[header] = value
try:
del response[header]
except KeyError:
pass
response[header] = value
return response return response
...@@ -29,4 +29,4 @@ class TestForceHeader(TestCase): ...@@ -29,4 +29,4 @@ class TestForceHeader(TestCase):
wrapped_view = wrapper(fake_view) wrapped_view = wrapper(fake_view)
response = wrapped_view(request) response = wrapped_view(request)
self.assertEqual(len(response.force_headers), 1) self.assertEqual(len(response.force_headers), 1)
self.assertEqual(response.force_headers['Vary'], 'Origin') self.assertEqual(response.force_headers['Vary'], 'Origin')
\ No newline at end of file
...@@ -11,6 +11,29 @@ class TestHeaderControlMiddlewareProcessResponse(TestCase): ...@@ -11,6 +11,29 @@ class TestHeaderControlMiddlewareProcessResponse(TestCase):
super(TestHeaderControlMiddlewareProcessResponse, self).setUp() super(TestHeaderControlMiddlewareProcessResponse, self).setUp()
self.middleware = HeaderControlMiddleware() self.middleware = HeaderControlMiddleware()
def test_doesnt_barf_if_not_modifying_anything(self):
fake_request = HttpRequest()
fake_response = HttpResponse()
fake_response['Vary'] = 'Cookie'
fake_response['Accept-Encoding'] = 'gzip'
result = self.middleware.process_response(fake_request, fake_response)
self.assertEquals('Cookie', result['Vary'])
self.assertEquals('gzip', result['Accept-Encoding'])
def test_doesnt_barf_removing_nonexistent_headers(self):
fake_request = HttpRequest()
fake_response = HttpResponse()
fake_response['Vary'] = 'Cookie'
fake_response['Accept-Encoding'] = 'gzip'
remove_headers_from_response(fake_response, 'Vary', 'FakeHeaderWeeee')
result = self.middleware.process_response(fake_request, fake_response)
self.assertNotIn('Vary', result)
self.assertEquals('gzip', result['Accept-Encoding'])
def test_removes_intended_headers(self): def test_removes_intended_headers(self):
fake_request = HttpRequest() fake_request = HttpRequest()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment