Commit d43ffd3a by Phil McGachey

[LTI Provider] Refactoring and clean-up

parent d94c0ab1
# -*- coding: utf-8 -*-
# pylint: disable=invalid-name, missing-docstring, unused-argument, unused-import, line-too-long
from south.utils import datetime_utils as datetime
from south.db import db
from south.v2 import SchemaMigration
from django.db import models
class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'LtiConsumer'
db.create_table('lti_provider_lticonsumer', (
('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('consumer_name', self.gf('django.db.models.fields.CharField')(max_length=255)),
('consumer_key', self.gf('django.db.models.fields.CharField')(unique=True, max_length=32, db_index=True)),
('consumer_secret', self.gf('django.db.models.fields.CharField')(unique=True, max_length=32)),
))
db.send_create_signal('lti_provider', ['LtiConsumer'])
def backwards(self, orm):
# Deleting model 'LtiConsumer'
db.delete_table('lti_provider_lticonsumer')
models = {
'lti_provider.lticonsumer': {
'Meta': {'object_name': 'LtiConsumer'},
'consumer_key': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '32', 'db_index': 'True'}),
'consumer_name': ('django.db.models.fields.CharField', [], {'max_length': '255'}),
'consumer_secret': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '32'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'})
}
}
complete_apps = ['lti_provider']
\ No newline at end of file
""" """
Database models for the LTI provider feature. Database models for the LTI provider feature.
This app uses migrations. If you make changes to this model, be sure to create
an appropriate migration file and check it in at the same time as your model
changes. To do that,
1. Go to the edx-platform dir
2. ./manage.py lms schemamigration lti_provider --auto "description" --settings=devstack
""" """
from django.db import models from django.db import models
from django.dispatch import receiver from django.dispatch import receiver
...@@ -13,8 +20,9 @@ class LtiConsumer(models.Model): ...@@ -13,8 +20,9 @@ class LtiConsumer(models.Model):
specific settings, such as the OAuth key/secret pair and any LTI fields specific settings, such as the OAuth key/secret pair and any LTI fields
that must be persisted. that must be persisted.
""" """
key = models.CharField(max_length=32, unique=True, db_index=True) consumer_name = models.CharField(max_length=255)
secret = models.CharField(max_length=32, unique=True) consumer_key = models.CharField(max_length=32, unique=True, db_index=True)
consumer_secret = models.CharField(max_length=32, unique=True)
@receiver(SCORE_CHANGED) @receiver(SCORE_CHANGED)
......
...@@ -79,7 +79,7 @@ class SignatureValidator(RequestValidator): ...@@ -79,7 +79,7 @@ class SignatureValidator(RequestValidator):
:return: True if the key is valid, False if it is not. :return: True if the key is valid, False if it is not.
""" """
return LtiConsumer.objects.filter(key=client_key).count() == 1 return LtiConsumer.objects.filter(consumer_key=client_key).count() == 1
def get_client_secret(self, client_key, request): def get_client_secret(self, client_key, request):
""" """
...@@ -90,7 +90,7 @@ class SignatureValidator(RequestValidator): ...@@ -90,7 +90,7 @@ class SignatureValidator(RequestValidator):
present, or None if the key does not exist in the database. present, or None if the key does not exist in the database.
""" """
try: try:
return LtiConsumer.objects.get(key=client_key).secret return LtiConsumer.objects.get(consumer_key=client_key).consumer_secret
except ObjectDoesNotExist: except ObjectDoesNotExist:
return None return None
......
...@@ -78,7 +78,7 @@ class SignatureValidatorTest(TestCase): ...@@ -78,7 +78,7 @@ class SignatureValidatorTest(TestCase):
Verify that validate_client_key succeeds if the client key exists in the Verify that validate_client_key succeeds if the client key exists in the
database database
""" """
LtiConsumer.objects.create(key='client_key', secret='client_secret') LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
self.assertTrue(SignatureValidator().validate_client_key('client_key', None)) self.assertTrue(SignatureValidator().validate_client_key('client_key', None))
def test_validate_missing_key(self): def test_validate_missing_key(self):
...@@ -93,7 +93,7 @@ class SignatureValidatorTest(TestCase): ...@@ -93,7 +93,7 @@ class SignatureValidatorTest(TestCase):
Verify that get_client_secret returns the right value if the key is in Verify that get_client_secret returns the right value if the key is in
the database the database
""" """
LtiConsumer.objects.create(key='client_key', secret='client_secret') LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
secret = SignatureValidator().get_client_secret('client_key', None) secret = SignatureValidator().get_client_secret('client_key', None)
self.assertEqual(secret, 'client_secret') self.assertEqual(secret, 'client_secret')
......
...@@ -8,6 +8,7 @@ from mock import patch, MagicMock ...@@ -8,6 +8,7 @@ from mock import patch, MagicMock
from lti_provider import views from lti_provider import views
from lti_provider.signature_validator import SignatureValidator from lti_provider.signature_validator import SignatureValidator
from opaque_keys.edx.keys import CourseKey, UsageKey
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
...@@ -22,10 +23,11 @@ LTI_DEFAULT_PARAMS = { ...@@ -22,10 +23,11 @@ LTI_DEFAULT_PARAMS = {
'oauth_nonce': u'OAuth Nonce', 'oauth_nonce': u'OAuth Nonce',
} }
COURSE_KEY = CourseKey.from_string('some/course/id')
USAGE_KEY = UsageKey.from_string('i4x://some/course/problem/uuid').map_into_course(COURSE_KEY)
COURSE_PARAMS = { COURSE_PARAMS = {
'course_id': 'CourseID', 'course_key': COURSE_KEY,
'usage_id': 'UsageID' 'usage_key': USAGE_KEY
} }
...@@ -72,7 +74,7 @@ class LtiLaunchTest(TestCase): ...@@ -72,7 +74,7 @@ class LtiLaunchTest(TestCase):
Verifies that the LTI launch succeeds when passed a valid request. Verifies that the LTI launch succeeds when passed a valid request.
""" """
request = build_launch_request() request = build_launch_request()
views.lti_launch(request, COURSE_PARAMS['course_id'], COURSE_PARAMS['usage_id']) views.lti_launch(request, str(COURSE_PARAMS['course_key']), str(COURSE_PARAMS['usage_key']))
render.assert_called_with(request, ALL_PARAMS) render.assert_called_with(request, ALL_PARAMS)
def launch_with_missing_parameter(self, missing_param): def launch_with_missing_parameter(self, missing_param):
...@@ -112,10 +114,10 @@ class LtiLaunchTest(TestCase): ...@@ -112,10 +114,10 @@ class LtiLaunchTest(TestCase):
properly stored in the session properly stored in the session
""" """
request = build_launch_request() request = build_launch_request()
views.lti_launch(request, COURSE_PARAMS['course_id'], COURSE_PARAMS['usage_id']) views.lti_launch(request, str(COURSE_PARAMS['course_key']), str(COURSE_PARAMS['usage_key']))
session = request.session[views.LTI_SESSION_KEY] session = request.session[views.LTI_SESSION_KEY]
self.assertEqual(session['course_id'], 'CourseID', 'Course ID not set in the session') self.assertEqual(session['course_key'], COURSE_KEY, 'Course key not set in the session')
self.assertEqual(session['usage_id'], 'UsageID', 'Usage ID not set in the session') self.assertEqual(session['usage_key'], USAGE_KEY, 'Usage key not set in the session')
for key in views.REQUIRED_PARAMETERS: for key in views.REQUIRED_PARAMETERS:
self.assertEqual(session[key], request.POST[key], key + ' not set in the session') self.assertEqual(session[key], request.POST[key], key + ' not set in the session')
...@@ -126,7 +128,9 @@ class LtiLaunchTest(TestCase): ...@@ -126,7 +128,9 @@ class LtiLaunchTest(TestCase):
URL URL
""" """
request = build_launch_request(False) request = build_launch_request(False)
response = views.lti_launch(request, None, None) response = views.lti_launch(
request, str(COURSE_PARAMS['course_key']), str(COURSE_PARAMS['usage_key'])
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/accounts/login?next=/lti_provider/lti_run') self.assertEqual(response['Location'], '/accounts/login?next=/lti_provider/lti_run')
...@@ -170,7 +174,7 @@ class LtiRunTest(TestCase): ...@@ -170,7 +174,7 @@ class LtiRunTest(TestCase):
Verifies that the lti_run view returns a Forbidden status if the session Verifies that the lti_run view returns a Forbidden status if the session
is missing any of the required LTI parameters or course information. is missing any of the required LTI parameters or course information.
""" """
extra_keys = ['course_id', 'usage_id'] extra_keys = ['course_key', 'usage_key']
for key in views.REQUIRED_PARAMETERS + extra_keys: for key in views.REQUIRED_PARAMETERS + extra_keys:
request = build_run_request() request = build_run_request()
del request.session[views.LTI_SESSION_KEY][key] del request.session[views.LTI_SESSION_KEY][key]
...@@ -208,7 +212,6 @@ class RenderCoursewareTest(TestCase): ...@@ -208,7 +212,6 @@ class RenderCoursewareTest(TestCase):
self.module_mock = self.setup_patch('lti_provider.views.get_module_by_usage_id', (self.module_instance, None)) self.module_mock = self.setup_patch('lti_provider.views.get_module_by_usage_id', (self.module_instance, None))
self.access_mock = self.setup_patch('lti_provider.views.has_access', 'StaffAccess') self.access_mock = self.setup_patch('lti_provider.views.has_access', 'StaffAccess')
self.course_mock = self.setup_patch('lti_provider.views.get_course_with_access', 'CourseWithAccess') self.course_mock = self.setup_patch('lti_provider.views.get_course_with_access', 'CourseWithAccess')
self.key_mock = self.setup_patch('lti_provider.views.CourseKey.from_string', 'CourseKey')
def setup_patch(self, function_name, return_value): def setup_patch(self, function_name, return_value):
""" """
...@@ -228,21 +231,13 @@ class RenderCoursewareTest(TestCase): ...@@ -228,21 +231,13 @@ class RenderCoursewareTest(TestCase):
response = views.render_courseware(request, ALL_PARAMS.copy()) response = views.render_courseware(request, ALL_PARAMS.copy())
self.assertEqual(response, 'Rendered page') self.assertEqual(response, 'Rendered page')
def test_course_key(self):
"""
Verify that the correct course key is requested
"""
request = build_run_request()
views.render_courseware(request, ALL_PARAMS.copy())
self.key_mock.assert_called_with(ALL_PARAMS['course_id'])
def test_course_with_access(self): def test_course_with_access(self):
""" """
Verify that get_course_with_access is called with the right parameters Verify that get_course_with_access is called with the right parameters
""" """
request = build_run_request() request = build_run_request()
views.render_courseware(request, ALL_PARAMS.copy()) views.render_courseware(request, ALL_PARAMS.copy())
self.course_mock.assert_called_with(request.user, 'load', 'CourseKey') self.course_mock.assert_called_with(request.user, 'load', COURSE_KEY)
def test_has_access(self): def test_has_access(self):
""" """
...@@ -258,7 +253,7 @@ class RenderCoursewareTest(TestCase): ...@@ -258,7 +253,7 @@ class RenderCoursewareTest(TestCase):
""" """
request = build_run_request() request = build_run_request()
views.render_courseware(request, ALL_PARAMS.copy()) views.render_courseware(request, ALL_PARAMS.copy())
self.module_mock.assert_called_with(request, ALL_PARAMS['course_id'], ALL_PARAMS['usage_id']) self.module_mock.assert_called_with(request, str(ALL_PARAMS['course_key']), str(ALL_PARAMS['usage_key']))
def test_render(self): def test_render(self):
""" """
...@@ -278,7 +273,7 @@ class RenderCoursewareTest(TestCase): ...@@ -278,7 +273,7 @@ class RenderCoursewareTest(TestCase):
'disable_footer': True, 'disable_footer': True,
'disable_tabs': True, 'disable_tabs': True,
'staff_access': 'StaffAccess', 'staff_access': 'StaffAccess',
'xqa_server': 'http://your_xqa_server.com', 'xqa_server': 'http://example.com/xqa',
} }
request = build_run_request() request = build_run_request()
views.render_courseware(request, ALL_PARAMS.copy()) views.render_courseware(request, ALL_PARAMS.copy())
......
...@@ -6,7 +6,7 @@ from django.conf import settings ...@@ -6,7 +6,7 @@ from django.conf import settings
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.contrib.auth.views import redirect_to_login from django.contrib.auth.views import redirect_to_login
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import HttpResponseBadRequest, HttpResponseForbidden from django.http import HttpResponseBadRequest, HttpResponseForbidden, Http404
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from courseware.access import has_access from courseware.access import has_access
...@@ -14,7 +14,9 @@ from courseware.courses import get_course_with_access ...@@ -14,7 +14,9 @@ from courseware.courses import get_course_with_access
from courseware.module_render import get_module_by_usage_id from courseware.module_render import get_module_by_usage_id
from edxmako.shortcuts import render_to_response from edxmako.shortcuts import render_to_response
from lti_provider.signature_validator import SignatureValidator from lti_provider.signature_validator import SignatureValidator
from opaque_keys.edx.keys import CourseKey from lms_xblock.runtime import unquote_slashes
from opaque_keys.edx.keys import CourseKey, UsageKey
from opaque_keys import InvalidKeyError
# LTI launch parameters that must be present for a successful launch # LTI launch parameters that must be present for a successful launch
REQUIRED_PARAMETERS = [ REQUIRED_PARAMETERS = [
...@@ -64,8 +66,11 @@ def lti_launch(request, course_id, usage_id): ...@@ -64,8 +66,11 @@ def lti_launch(request, course_id, usage_id):
# Store the course, and usage ID in the session to prevent privilege # Store the course, and usage ID in the session to prevent privilege
# escalation if a staff member in one course tries to access material in # escalation if a staff member in one course tries to access material in
# another. # another.
params['course_id'] = course_id course_key, usage_key = parse_course_and_usage_keys(course_id, usage_id)
params['usage_id'] = usage_id if not course_key:
raise Http404('Invalid course or usage key')
params['course_key'] = course_key
params['usage_key'] = usage_key
request.session[LTI_SESSION_KEY] = params request.session[LTI_SESSION_KEY] = params
if not request.user.is_authenticated(): if not request.user.is_authenticated():
...@@ -140,7 +145,7 @@ def restore_params_from_session(request): ...@@ -140,7 +145,7 @@ def restore_params_from_session(request):
if LTI_SESSION_KEY not in request.session: if LTI_SESSION_KEY not in request.session:
return None return None
session_params = request.session[LTI_SESSION_KEY] session_params = request.session[LTI_SESSION_KEY]
additional_params = ['course_id', 'usage_id'] additional_params = ['course_key', 'usage_key']
return get_required_parameters(session_params, additional_params) return get_required_parameters(session_params, additional_params)
...@@ -154,13 +159,12 @@ def render_courseware(request, lti_params): ...@@ -154,13 +159,12 @@ def render_courseware(request, lti_params):
:return: an HttpResponse object that contains the template and necessary :return: an HttpResponse object that contains the template and necessary
context to render the courseware. context to render the courseware.
""" """
usage_id = lti_params['usage_id'] usage_key = lti_params['usage_key']
course_id = lti_params['course_id'] course_key = lti_params['course_key']
course_key = CourseKey.from_string(course_id)
user = request.user user = request.user
course = get_course_with_access(user, 'load', course_key) course = get_course_with_access(user, 'load', course_key)
staff = has_access(request.user, 'staff', course) staff = has_access(request.user, 'staff', course)
instance, _ = get_module_by_usage_id(request, course_id, usage_id) instance, _ = get_module_by_usage_id(request, str(course_key), str(usage_key))
fragment = instance.render('student_view', context=request.GET) fragment = instance.render('student_view', context=request.GET)
...@@ -173,7 +177,26 @@ def render_courseware(request, lti_params): ...@@ -173,7 +177,26 @@ def render_courseware(request, lti_params):
'disable_footer': True, 'disable_footer': True,
'disable_tabs': True, 'disable_tabs': True,
'staff_access': staff, 'staff_access': staff,
'xqa_server': settings.FEATURES.get('XQA_SERVER', 'http://your_xqa_server.com'), 'xqa_server': settings.FEATURES.get('XQA_SERVER', 'http://example.com/xqa'),
} }
return render_to_response('courseware/courseware.html', context) return render_to_response('courseware/courseware.html', context)
def parse_course_and_usage_keys(course_id, usage_id):
"""
Convert course and usage ID strings into key objects. Return a tuple of
(course_key, usage_key), or (None, None) if the translation fails.
"""
try:
course_key = CourseKey.from_string(course_id)
except InvalidKeyError:
return None, None
if not course_key:
return None, None
try:
usage_id = unquote_slashes(usage_id)
usage_key = UsageKey.from_string(usage_id).map_into_course(course_key)
except InvalidKeyError:
return None, None
return course_key, usage_key
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment