Commit 0836a896 by Andy Armstrong

Merge pull request #6189 from edx/andya/fix-masquerade

Fix Mako templates to always use updated request context
parents 5007e980 4d75c180
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distribuetd under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
...@@ -22,10 +22,24 @@ REQUEST_CONTEXT = threading.local() ...@@ -22,10 +22,24 @@ REQUEST_CONTEXT = threading.local()
class MakoMiddleware(object): class MakoMiddleware(object):
def process_request(self, request): def process_request(self, request):
REQUEST_CONTEXT.context = RequestContext(request) """ Process the middleware request. """
REQUEST_CONTEXT.context['is_secure'] = request.is_secure() REQUEST_CONTEXT.request = request
REQUEST_CONTEXT.context['site'] = safe_get_host(request)
def process_response(self, request, response): def process_response(self, __, response):
REQUEST_CONTEXT.context = None """ Process the middleware response. """
REQUEST_CONTEXT.request = None
return response return response
def get_template_request_context():
"""
Returns the template processing context to use for the current request,
or returns None if there is not a current request.
"""
request = getattr(REQUEST_CONTEXT, "request", None)
if not request:
return None
context = RequestContext(request)
context['is_secure'] = request.is_secure()
context['site'] = safe_get_host(request)
return context
...@@ -19,7 +19,7 @@ import logging ...@@ -19,7 +19,7 @@ import logging
from microsite_configuration import microsite from microsite_configuration import microsite
from edxmako import lookup_template from edxmako import lookup_template
import edxmako.middleware from edxmako.middleware import get_template_request_context
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -114,11 +114,12 @@ def render_to_string(template_name, dictionary, context=None, namespace='main'): ...@@ -114,11 +114,12 @@ def render_to_string(template_name, dictionary, context=None, namespace='main'):
context_instance['marketing_link'] = marketing_link context_instance['marketing_link'] = marketing_link
# In various testing contexts, there might not be a current request context. # In various testing contexts, there might not be a current request context.
if getattr(edxmako.middleware.REQUEST_CONTEXT, "context", None): request_context = get_template_request_context()
for d in edxmako.middleware.REQUEST_CONTEXT.context: if request_context:
context_dictionary.update(d) for item in request_context:
for d in context_instance: context_dictionary.update(item)
context_dictionary.update(d) for item in context_instance:
context_dictionary.update(item)
if context: if context:
context_dictionary.update(context) context_dictionary.update(context)
# fetch and render template # fetch and render template
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import edxmako
from django.conf import settings from django.conf import settings
from mako.template import Template as MakoTemplate from edxmako.middleware import get_template_request_context
from edxmako.shortcuts import marketing_link from edxmako.shortcuts import marketing_link
from mako.template import Template as MakoTemplate
import edxmako
import edxmako.middleware
DJANGO_VARIABLES = ['output_encoding', 'encoding_errors'] DJANGO_VARIABLES = ['output_encoding', 'encoding_errors']
...@@ -48,11 +48,12 @@ class Template(MakoTemplate): ...@@ -48,11 +48,12 @@ class Template(MakoTemplate):
context_dictionary = {} context_dictionary = {}
# In various testing contexts, there might not be a current request context. # In various testing contexts, there might not be a current request context.
if getattr(edxmako.middleware.REQUEST_CONTEXT, "context", None): request_context = get_template_request_context()
for d in edxmako.middleware.REQUEST_CONTEXT.context: if request_context:
context_dictionary.update(d) for item in request_context:
for d in context_instance: context_dictionary.update(item)
context_dictionary.update(d) for item in context_instance:
context_dictionary.update(item)
context_dictionary['settings'] = settings context_dictionary['settings'] = settings
context_dictionary['EDX_ROOT_URL'] = settings.EDX_ROOT_URL context_dictionary['EDX_ROOT_URL'] = settings.EDX_ROOT_URL
context_dictionary['django_context'] = context_instance context_dictionary['django_context'] = context_instance
......
...@@ -10,6 +10,7 @@ from django.test.utils import override_settings ...@@ -10,6 +10,7 @@ from django.test.utils import override_settings
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
import edxmako.middleware import edxmako.middleware
from edxmako.middleware import get_template_request_context
from edxmako import add_lookup, LOOKUP from edxmako import add_lookup, LOOKUP
from edxmako.shortcuts import ( from edxmako.shortcuts import (
marketing_link, marketing_link,
...@@ -83,11 +84,11 @@ class MakoMiddlewareTest(TestCase): ...@@ -83,11 +84,11 @@ class MakoMiddlewareTest(TestCase):
self.middleware.process_request(self.request) self.middleware.process_request(self.request)
# requestcontext should not be None. # requestcontext should not be None.
self.assertIsNotNone(edxmako.middleware.REQUEST_CONTEXT.context) self.assertIsNotNone(get_template_request_context())
self.middleware.process_response(self.request, self.response) self.middleware.process_response(self.request, self.response)
# requestcontext should be None. # requestcontext should be None.
self.assertIsNone(edxmako.middleware.REQUEST_CONTEXT.context) self.assertIsNone(get_template_request_context())
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@patch("edxmako.middleware.REQUEST_CONTEXT") @patch("edxmako.middleware.REQUEST_CONTEXT")
......
...@@ -4,6 +4,7 @@ Base class for pages in courseware. ...@@ -4,6 +4,7 @@ Base class for pages in courseware.
from bok_choy.page_object import PageObject from bok_choy.page_object import PageObject
from . import BASE_URL from . import BASE_URL
from .tab_nav import TabNavPage
class CoursePage(PageObject): class CoursePage(PageObject):
...@@ -29,3 +30,11 @@ class CoursePage(PageObject): ...@@ -29,3 +30,11 @@ class CoursePage(PageObject):
Construct a URL to the page within the course. Construct a URL to the page within the course.
""" """
return BASE_URL + "/courses/" + self.course_id + "/" + self.url_path return BASE_URL + "/courses/" + self.course_id + "/" + self.url_path
def has_tab(self, tab_name):
"""
Returns true if the current page is showing a tab with the given name.
:return:
"""
tab_nav = TabNavPage(self.browser)
return tab_name in tab_nav.tab_names
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
Staff view of courseware Staff view of courseware
""" """
from bok_choy.page_object import PageObject from bok_choy.page_object import PageObject
from .courseware import CoursewarePage
class StaffPage(PageObject): class StaffPage(CoursewarePage):
""" """
View of courseware pages while logged in as course staff View of courseware pages while logged in as course staff
""" """
...@@ -13,6 +14,8 @@ class StaffPage(PageObject): ...@@ -13,6 +14,8 @@ class StaffPage(PageObject):
STAFF_STATUS_CSS = '#staffstatus' STAFF_STATUS_CSS = '#staffstatus'
def is_browser_on_page(self): def is_browser_on_page(self):
if not super(StaffPage, self).is_browser_on_page():
return False
return self.q(css=self.STAFF_STATUS_CSS).present return self.q(css=self.STAFF_STATUS_CSS).present
@property @property
......
...@@ -48,7 +48,7 @@ class TabNavPage(PageObject): ...@@ -48,7 +48,7 @@ class TabNavPage(PageObject):
Return the CSS to click for `tab_name`. Return the CSS to click for `tab_name`.
If no tabs exist for that name, return `None`. If no tabs exist for that name, return `None`.
""" """
all_tabs = self._tab_names all_tabs = self.tab_names
try: try:
tab_index = all_tabs.index(tab_name) tab_index = all_tabs.index(tab_name)
...@@ -58,7 +58,7 @@ class TabNavPage(PageObject): ...@@ -58,7 +58,7 @@ class TabNavPage(PageObject):
return 'ol.course-tabs li:nth-of-type({0}) a'.format(tab_index + 1) return 'ol.course-tabs li:nth-of-type({0}) a'.format(tab_index + 1)
@property @property
def _tab_names(self): def tab_names(self):
""" """
Return the list of available tab names. If no tab names Return the list of available tab names. If no tab names
are available, wait for them to load. Raises a `BrokenPromiseError` are available, wait for them to load. Raises a `BrokenPromiseError`
......
...@@ -11,15 +11,15 @@ from ...fixtures.course import CourseFixture, XBlockFixtureDesc ...@@ -11,15 +11,15 @@ from ...fixtures.course import CourseFixture, XBlockFixtureDesc
from textwrap import dedent from textwrap import dedent
class StaffDebugTest(UniqueCourseTest): class StaffViewTest(UniqueCourseTest):
""" """
Tests that verify the staff debug info. Tests that verify the staff view.
""" """
USERNAME = "STAFF_TESTER" USERNAME = "STAFF_TESTER"
EMAIL = "johndoe@example.com" EMAIL = "johndoe@example.com"
def setUp(self): def setUp(self):
super(StaffDebugTest, self).setUp() super(StaffViewTest, self).setUp()
self.courseware_page = CoursewarePage(self.browser, self.course_id) self.courseware_page = CoursewarePage(self.browser, self.course_id)
...@@ -59,10 +59,31 @@ class StaffDebugTest(UniqueCourseTest): ...@@ -59,10 +59,31 @@ class StaffDebugTest(UniqueCourseTest):
Open staff page with assertion Open staff page with assertion
""" """
self.courseware_page.visit() self.courseware_page.visit()
staff_page = StaffPage(self.browser) staff_page = StaffPage(self.browser, self.course_id)
self.assertEqual(staff_page.staff_status, 'Staff view') self.assertEqual(staff_page.staff_status, 'Staff view')
return staff_page return staff_page
class StaffViewToggleTest(StaffViewTest):
"""
Tests for the staff view toggle button.
"""
def test_instructor_tab_visibility(self):
"""
Test that the instructor tab is hidden when viewing as a student.
"""
course_page = self._goto_staff_page()
self.assertTrue(course_page.has_tab('Instructor'))
course_page.toggle_staff_view()
self.assertEqual(course_page.staff_status, 'Student view')
self.assertFalse(course_page.has_tab('Instructor'))
class StaffDebugTest(StaffViewTest):
"""
Tests that verify the staff debug info.
"""
def test_reset_attempts_empty(self): def test_reset_attempts_empty(self):
""" """
Test that we reset even when there is no student state Test that we reset even when there is no student state
......
...@@ -678,7 +678,7 @@ class UnitPublishingTest(ContainerBase): ...@@ -678,7 +678,7 @@ class UnitPublishingTest(ContainerBase):
""" """
Verifies that the browser is on the staff page and returns a StaffPage. Verifies that the browser is on the staff page and returns a StaffPage.
""" """
page = StaffPage(self.browser) page = StaffPage(self.browser, self.course_id)
EmptyPromise(page.is_browser_on_page, 'Browser is on staff page in LMS').fulfill() EmptyPromise(page.is_browser_on_page, 'Browser is on staff page in LMS').fulfill()
return page return page
......
...@@ -718,7 +718,7 @@ class StaffLockTest(CourseOutlineTest): ...@@ -718,7 +718,7 @@ class StaffLockTest(CourseOutlineTest):
courseware = CoursewarePage(self.browser, self.course_id) courseware = CoursewarePage(self.browser, self.course_id)
courseware.wait_for_page() courseware.wait_for_page()
self.assertEqual(courseware.num_sections, 2) self.assertEqual(courseware.num_sections, 2)
StaffPage(self.browser).toggle_staff_view() StaffPage(self.browser, self.course_id).toggle_staff_view()
self.assertEqual(courseware.num_sections, 1) self.assertEqual(courseware.num_sections, 1)
def test_locked_subsections_do_not_appear_in_lms(self): def test_locked_subsections_do_not_appear_in_lms(self):
...@@ -737,7 +737,7 @@ class StaffLockTest(CourseOutlineTest): ...@@ -737,7 +737,7 @@ class StaffLockTest(CourseOutlineTest):
courseware = CoursewarePage(self.browser, self.course_id) courseware = CoursewarePage(self.browser, self.course_id)
courseware.wait_for_page() courseware.wait_for_page()
self.assertEqual(courseware.num_subsections, 2) self.assertEqual(courseware.num_subsections, 2)
StaffPage(self.browser).toggle_staff_view() StaffPage(self.browser, self.course_id).toggle_staff_view()
self.assertEqual(courseware.num_subsections, 1) self.assertEqual(courseware.num_subsections, 1)
def test_toggling_staff_lock_on_section_does_not_publish_draft_units(self): def test_toggling_staff_lock_on_section_does_not_publish_draft_units(self):
......
...@@ -92,7 +92,7 @@ def make_track_function(request): ...@@ -92,7 +92,7 @@ def make_track_function(request):
return function return function
def toc_for_course(user, request, course, active_chapter, active_section, field_data_cache): def toc_for_course(request, course, active_chapter, active_section, field_data_cache):
''' '''
Create a table of contents from the module store Create a table of contents from the module store
...@@ -117,7 +117,7 @@ def toc_for_course(user, request, course, active_chapter, active_section, field_ ...@@ -117,7 +117,7 @@ def toc_for_course(user, request, course, active_chapter, active_section, field_
''' '''
with modulestore().bulk_operations(course.id): with modulestore().bulk_operations(course.id):
course_module = get_module_for_descriptor(user, request, course, field_data_cache, course.id) course_module = get_module_for_descriptor(request.user, request, course, field_data_cache, course.id)
if course_module is None: if course_module is None:
return None return None
......
...@@ -394,7 +394,7 @@ class TestTOC(ModuleStoreTestCase): ...@@ -394,7 +394,7 @@ class TestTOC(ModuleStoreTestCase):
with check_mongo_calls(toc_finds): with check_mongo_calls(toc_finds):
actual = render.toc_for_course( actual = render.toc_for_course(
self.request.user, self.request, self.toy_course, self.chapter, None, self.field_data_cache self.request, self.toy_course, self.chapter, None, self.field_data_cache
) )
for toc_section in expected: for toc_section in expected:
self.assertIn(toc_section, actual) self.assertIn(toc_section, actual)
...@@ -432,7 +432,9 @@ class TestTOC(ModuleStoreTestCase): ...@@ -432,7 +432,9 @@ class TestTOC(ModuleStoreTestCase):
'url_name': 'secret:magic', 'display_name': 'secret:magic'}]) 'url_name': 'secret:magic', 'display_name': 'secret:magic'}])
with check_mongo_calls(toc_finds): with check_mongo_calls(toc_finds):
actual = render.toc_for_course(self.request.user, self.request, self.toy_course, self.chapter, section, self.field_data_cache) actual = render.toc_for_course(
self.request, self.toy_course, self.chapter, section, self.field_data_cache
)
for toc_section in expected: for toc_section in expected:
self.assertIn(toc_section, actual) self.assertIn(toc_section, actual)
......
...@@ -555,7 +555,8 @@ class TestAccordionDueDate(BaseDueDateTests): ...@@ -555,7 +555,8 @@ class TestAccordionDueDate(BaseDueDateTests):
def get_text(self, course): def get_text(self, course):
""" Returns the HTML for the accordion """ """ Returns the HTML for the accordion """
return views.render_accordion( return views.render_accordion(
self.request, course, course.get_children()[0].scope_ids.usage_id.to_deprecated_string(), None, None self.request, course, course.get_children()[0].scope_ids.usage_id.to_deprecated_string(),
None, None
) )
......
...@@ -118,9 +118,7 @@ def render_accordion(request, course, chapter, section, field_data_cache): ...@@ -118,9 +118,7 @@ def render_accordion(request, course, chapter, section, field_data_cache):
Returns the html string Returns the html string
""" """
# grab the table of contents # grab the table of contents
user = User.objects.prefetch_related("groups").get(id=request.user.id) toc = toc_for_course(request, course, chapter, section, field_data_cache)
request.user = user # keep just one instance of User
toc = toc_for_course(user, request, course, chapter, section, field_data_cache)
context = dict([ context = dict([
('toc', toc), ('toc', toc),
...@@ -325,10 +323,15 @@ def index(request, course_id, chapter=None, section=None, ...@@ -325,10 +323,15 @@ def index(request, course_id, chapter=None, section=None,
request.user = user # keep just one instance of User request.user = user # keep just one instance of User
with modulestore().bulk_operations(course_key): with modulestore().bulk_operations(course_key):
return _index_bulk_op(request, user, course_key, chapter, section, position) return _index_bulk_op(request, course_key, chapter, section, position)
def _index_bulk_op(request, user, course_key, chapter, section, position): # pylint: disable=too-many-statements
def _index_bulk_op(request, course_key, chapter, section, position):
"""
Render the index page for the specified course.
"""
user = request.user
course = get_course_with_access(user, 'load', course_key, depth=2) course = get_course_with_access(user, 'load', course_key, depth=2)
staff_access = has_access(user, 'staff', course) staff_access = has_access(user, 'staff', course)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment