Commit 1a346415 by Toby Lawrence

Switch to header_control in contentserver.

parent 4ca2692f
......@@ -311,7 +311,7 @@ simplefilter('ignore')
MIDDLEWARE_CLASSES = (
'request_cache.middleware.RequestCache',
'clean_headers.middleware.CleanHeadersMiddleware',
'header_control.middleware.HeaderControlMiddleware',
'django.middleware.cache.UpdateCacheMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
......
......@@ -11,7 +11,7 @@ from django.http import (
from student.models import CourseEnrollment
from contentserver.models import CourseAssetCacheTtlConfig
from clean_headers import remove_headers_from_response
from header_control import force_header_for_response
from xmodule.assetstore.assetmgr import AssetManager
from xmodule.contentstore.content import StaticContent, XASSET_LOCATION_TAG
from xmodule.modulestore import InvalidLocationError
......@@ -153,7 +153,10 @@ class StaticContentServer(object):
response['Last-Modified'] = content.last_modified_at.strftime(HTTP_DATE_FORMAT)
remove_headers_from_response(response, "Vary")
# Force the Vary header to only vary responses on Origin, so that XHR and browser requests get cached
# separately and don't screw over one another. i.e. a browser request that doesn't send Origin, and
# caches a version of the response without CORS headers, in turn breaking XHR requests.
force_header_for_response(response, 'Vary', 'Origin')
@staticmethod
def get_expiration_value(now, cache_ttl):
......
......@@ -16,12 +16,13 @@ from mock import patch
from xmodule.contentstore.django import contentstore
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.xml_importer import import_course_from_xml
from contentserver.middleware import parse_range_header, HTTP_DATE_FORMAT, StaticContentServer
from student.models import CourseEnrollment
from student.tests.factories import UserFactory, AdminFactory
log = logging.getLogger(__name__)
......@@ -33,39 +34,44 @@ TEST_DATA_DIR = settings.COMMON_TEST_DATA_ROOT
@ddt.ddt
@override_settings(CONTENTSTORE=TEST_DATA_CONTENTSTORE)
class ContentStoreToyCourseTest(ModuleStoreTestCase):
class ContentStoreToyCourseTest(SharedModuleStoreTestCase):
"""
Tests that use the toy course.
"""
def setUp(self):
"""
Create user and login.
"""
self.staff_pwd = super(ContentStoreToyCourseTest, self).setUp()
self.staff_usr = self.user
self.non_staff_usr, self.non_staff_pwd = self.create_non_staff_user()
@classmethod
def setUpClass(cls):
super(ContentStoreToyCourseTest, cls).setUpClass()
self.client = Client()
self.contentstore = contentstore()
store = modulestore()._get_modulestore_by_type(ModuleStoreEnum.Type.mongo) # pylint: disable=protected-access
cls.contentstore = contentstore()
cls.modulestore = modulestore()
self.course_key = store.make_course_key('edX', 'toy', '2012_Fall')
cls.course_key = cls.modulestore.make_course_key('edX', 'toy', '2012_Fall')
import_course_from_xml(
store, self.user.id, TEST_DATA_DIR, ['toy'],
static_content_store=self.contentstore, verbose=True
cls.modulestore, 1, TEST_DATA_DIR, ['toy'],
static_content_store=cls.contentstore, verbose=True
)
# A locked asset
self.locked_asset = self.course_key.make_asset_key('asset', 'sample_static.txt')
self.url_locked = unicode(self.locked_asset)
self.contentstore.set_attr(self.locked_asset, 'locked', True)
cls.locked_asset = cls.course_key.make_asset_key('asset', 'sample_static.txt')
cls.url_locked = unicode(cls.locked_asset)
cls.contentstore.set_attr(cls.locked_asset, 'locked', True)
# An unlocked asset
self.unlocked_asset = self.course_key.make_asset_key('asset', 'another_static.txt')
self.url_unlocked = unicode(self.unlocked_asset)
self.length_unlocked = self.contentstore.get_attr(self.unlocked_asset, 'length')
cls.unlocked_asset = cls.course_key.make_asset_key('asset', 'another_static.txt')
cls.url_unlocked = unicode(cls.unlocked_asset)
cls.length_unlocked = cls.contentstore.get_attr(cls.unlocked_asset, 'length')
def setUp(self):
"""
Create user and login.
"""
super(ContentStoreToyCourseTest, self).setUp()
self.staff_usr = AdminFactory.create()
self.non_staff_usr = UserFactory.create()
self.client = Client()
def test_unlocked_asset(self):
"""
......@@ -89,7 +95,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
Test that locked assets behave appropriately in case user is logged in
in but not registered for the course.
"""
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 403)
......@@ -101,7 +107,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
......@@ -109,7 +115,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
"""
Test that locked assets behave appropriately in case user is staff.
"""
self.client.login(username=self.staff_usr, password=self.staff_pwd)
self.client.login(username=self.staff_usr, password='test')
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
......@@ -191,6 +197,15 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
first=(self.length_unlocked), last=(self.length_unlocked)))
self.assertEqual(resp.status_code, 416)
def test_vary_header_sent(self):
"""
Tests that we're properly setting the Vary header to ensure browser requests don't get
cached in a way that breaks XHR requests to the same asset.
"""
resp = self.client.get(self.url_unlocked)
self.assertEqual(resp.status_code, 200)
self.assertEquals('Origin', resp['Vary'])
@patch('contentserver.models.CourseAssetCacheTtlConfig.get_cache_ttl')
def test_cache_headers_with_ttl_unlocked(self, mock_get_cache_ttl):
"""
......@@ -215,7 +230,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Expires', resp)
......@@ -245,7 +260,7 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
self.client.login(username=self.non_staff_usr, password='test')
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Expires', resp)
......@@ -256,20 +271,6 @@ class ContentStoreToyCourseTest(ModuleStoreTestCase):
near_expire_dt = StaticContentServer.get_expiration_value(start_dt, 55)
self.assertEqual("Thu, 01 Dec 1983 20:00:55 GMT", near_expire_dt)
def test_response_no_vary_header_unlocked(self):
resp = self.client.get(self.url_unlocked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Vary', resp)
def test_response_no_vary_header_locked(self):
CourseEnrollment.enroll(self.non_staff_usr, self.course_key)
self.assertTrue(CourseEnrollment.is_enrolled(self.non_staff_usr, self.course_key))
self.client.login(username=self.non_staff_usr, password=self.non_staff_pwd)
resp = self.client.get(self.url_locked)
self.assertEqual(resp.status_code, 200)
self.assertNotIn('Vary', resp)
@ddt.ddt
class ParseRangeHeaderTestCase(unittest.TestCase):
......
......@@ -11,6 +11,7 @@ def remove_headers_from_response(response, *headers):
"""Removes the given headers from the response using the header_control middleware."""
response.remove_headers = headers
def force_header_for_response(response, header, value):
"""Forces the given header for the given response using the header_control middleware."""
force_headers = {}
......
......@@ -5,6 +5,7 @@ Middleware decorator for removing headers.
from functools import wraps
from header_control import remove_headers_from_response, force_header_for_response
def remove_headers(*headers):
"""
Decorator that removes specific headers from the response.
......
......@@ -15,20 +15,10 @@ class HeaderControlMiddleware(object):
Processes the given response, potentially remove or modifying headers.
"""
if len(getattr(response, 'remove_headers', [])) > 0:
for header in response.remove_headers:
try:
del response[header]
except KeyError:
pass
for header in getattr(response, 'remove_headers', []):
del response[header]
if len(getattr(response, 'force_headers', {})) > 0:
for header, value in response.force_headers.iteritems():
try:
del response[header]
except KeyError:
pass
response[header] = value
for header, value in getattr(response, 'force_headers', {}).iteritems():
response[header] = value
return response
......@@ -29,4 +29,4 @@ class TestForceHeader(TestCase):
wrapped_view = wrapper(fake_view)
response = wrapped_view(request)
self.assertEqual(len(response.force_headers), 1)
self.assertEqual(response.force_headers['Vary'], 'Origin')
\ No newline at end of file
self.assertEqual(response.force_headers['Vary'], 'Origin')
......@@ -11,6 +11,29 @@ class TestHeaderControlMiddlewareProcessResponse(TestCase):
super(TestHeaderControlMiddlewareProcessResponse, self).setUp()
self.middleware = HeaderControlMiddleware()
def test_doesnt_barf_if_not_modifying_anything(self):
fake_request = HttpRequest()
fake_response = HttpResponse()
fake_response['Vary'] = 'Cookie'
fake_response['Accept-Encoding'] = 'gzip'
result = self.middleware.process_response(fake_request, fake_response)
self.assertEquals('Cookie', result['Vary'])
self.assertEquals('gzip', result['Accept-Encoding'])
def test_doesnt_barf_removing_nonexistent_headers(self):
fake_request = HttpRequest()
fake_response = HttpResponse()
fake_response['Vary'] = 'Cookie'
fake_response['Accept-Encoding'] = 'gzip'
remove_headers_from_response(fake_response, 'Vary', 'FakeHeaderWeeee')
result = self.middleware.process_response(fake_request, fake_response)
self.assertNotIn('Vary', result)
self.assertEquals('gzip', result['Accept-Encoding'])
def test_removes_intended_headers(self):
fake_request = HttpRequest()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment