Commit 2e88e338 by Waheed Ahmed

Refactor verify course id decorator and fixed tests.

parent 8a7fef07
...@@ -188,7 +188,7 @@ class TestLTIModuleListing(ModuleStoreTestCase): ...@@ -188,7 +188,7 @@ class TestLTIModuleListing(ModuleStoreTestCase):
"""tests that the draft lti module is part of the endpoint response""" """tests that the draft lti module is part of the endpoint response"""
request = mock.Mock() request = mock.Mock()
request.method = 'GET' request.method = 'GET'
response = get_course_lti_endpoints(request, self.course.id.to_deprecated_string()) response = get_course_lti_endpoints(request, course_id=self.course.id.to_deprecated_string())
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertEqual('application/json', response['Content-Type']) self.assertEqual('application/json', response['Content-Type'])
......
...@@ -463,7 +463,7 @@ class TestProgressDueDate(BaseDueDateTests): ...@@ -463,7 +463,7 @@ class TestProgressDueDate(BaseDueDateTests):
""" Returns the HTML for the progress page """ """ Returns the HTML for the progress page """
mako_middleware_process_request(self.request) mako_middleware_process_request(self.request)
return views.progress(self.request, course.id.to_deprecated_string(), self.user.id).content return views.progress(self.request, course_id=course.id.to_deprecated_string(), student_id=self.user.id).content
class TestAccordionDueDate(BaseDueDateTests): class TestAccordionDueDate(BaseDueDateTests):
...@@ -560,11 +560,11 @@ class ProgressPageTests(ModuleStoreTestCase): ...@@ -560,11 +560,11 @@ class ProgressPageTests(ModuleStoreTestCase):
def test_pure_ungraded_xblock(self): def test_pure_ungraded_xblock(self):
ItemFactory(category='acid', parent_location=self.vertical.location) ItemFactory(category='acid', parent_location=self.vertical.location)
resp = views.progress(self.request, self.course.id.to_deprecated_string()) resp = views.progress(self.request, course_id=self.course.id.to_deprecated_string())
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
def test_non_asci_grade_cutoffs(self): def test_non_asci_grade_cutoffs(self):
resp = views.progress(self.request, self.course.id.to_deprecated_string()) resp = views.progress(self.request, course_id=self.course.id.to_deprecated_string())
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
...@@ -581,11 +581,11 @@ class TestVerifyCourseIdDecorator(TestCase): ...@@ -581,11 +581,11 @@ class TestVerifyCourseIdDecorator(TestCase):
def test_decorator_with_valid_course_id(self): def test_decorator_with_valid_course_id(self):
mocked_view = create_autospec(views.course_about) mocked_view = create_autospec(views.course_about)
view_function = views.verify_course_id(mocked_view) view_function = views.verify_course_id(mocked_view)
view_function(self.request, self.valid_course_id) view_function(self.request, course_id=self.valid_course_id)
self.assertTrue(mocked_view.called) self.assertTrue(mocked_view.called)
def test_decorator_with_invalid_course_id(self): def test_decorator_with_invalid_course_id(self):
mocked_view = create_autospec(views.course_about) mocked_view = create_autospec(views.course_about)
view_function = views.verify_course_id(mocked_view) view_function = views.verify_course_id(mocked_view)
self.assertRaises(Http404, view_function, self.request, self.invalid_course_id) self.assertRaises(Http404, view_function, self.request, course_id=self.invalid_course_id)
self.assertFalse(mocked_view.called) self.assertFalse(mocked_view.called)
...@@ -84,17 +84,18 @@ def user_groups(user): ...@@ -84,17 +84,18 @@ def user_groups(user):
def verify_course_id(view_func): def verify_course_id(view_func):
""" """
This decorator should only be used with views whose second argument is course_id. This decorator should only be used with views whose kwargs must contain course_id.
If course_id is not valid raise 404. If course_id is not valid raise 404.
""" """
@wraps(view_func) @wraps(view_func)
def _decorated(request, course_id, *args, **kwargs): def _decorated(request, *args, **kwargs):
course_id = kwargs.get("course_id")
try: try:
SlashSeparatedCourseKey.from_deprecated_string(course_id) SlashSeparatedCourseKey.from_deprecated_string(course_id)
except InvalidKeyError: except InvalidKeyError:
raise Http404 raise Http404
response = view_func(request, course_id, *args, **kwargs) response = view_func(request, *args, **kwargs)
return response return response
return _decorated return _decorated
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment