Commit a27865ab by David Ormsbee

Merge pull request #8647 from mcgachey/lti-view-refactoring

[LTI Provider] Refactoring to remove the lti_run method
parents 43ba78ba 99fbf4d6
......@@ -5,8 +5,6 @@ Subclass of oauthlib's RequestValidator that checks an OAuth signature.
from oauthlib.oauth1 import SignatureOnlyEndpoint
from oauthlib.oauth1 import RequestValidator
from lti_provider.models import LtiConsumer
class SignatureValidator(RequestValidator):
"""
......@@ -18,9 +16,10 @@ class SignatureValidator(RequestValidator):
application-specific requirements.
"""
def __init__(self):
def __init__(self, lti_consumer):
super(SignatureValidator, self).__init__()
self.endpoint = SignatureOnlyEndpoint(self)
self.lti_consumer = lti_consumer
# The OAuth signature uses the endpoint URL as part of the request to be
# hashed. By default, the oauthlib library rejects any URLs that do not
......@@ -77,7 +76,7 @@ class SignatureValidator(RequestValidator):
:return: True if the key is valid, False if it is not.
"""
return LtiConsumer.objects.filter(consumer_key=client_key).count() == 1
return self.lti_consumer.consumer_key == client_key
def get_client_secret(self, client_key, request):
"""
......@@ -87,10 +86,7 @@ class SignatureValidator(RequestValidator):
:return: the client secret that corresponds to the supplied key if
present, or None if the key does not exist in the database.
"""
try:
return LtiConsumer.objects.get(consumer_key=client_key).consumer_secret
except LtiConsumer.DoesNotExist:
return None
return self.lti_consumer.consumer_secret
def verify(self, request):
"""
......
......@@ -2,6 +2,7 @@
Tests for the SignatureValidator class.
"""
import ddt
from django.test import TestCase
from django.test.client import RequestFactory
from mock import patch
......@@ -10,100 +11,95 @@ from lti_provider.models import LtiConsumer
from lti_provider.signature_validator import SignatureValidator
class SignatureValidatorTest(TestCase):
def get_lti_consumer():
"""
Tests for the custom SignatureValidator class that uses the oauthlib library
to check message signatures. Note that these tests mock out the library
itself, since we assume it to be correct.
Helper method for all Signature Validator tests to get an LtiConsumer object.
"""
return LtiConsumer(
consumer_name='Consumer Name',
consumer_key='Consumer Key',
consumer_secret='Consumer Secret'
)
@ddt.ddt
class ClientKeyValidatorTest(TestCase):
"""
Tests for the check_client_key method in the SignatureValidator class.
"""
def setUp(self):
super(ClientKeyValidatorTest, self).setUp()
self.lti_consumer = get_lti_consumer()
def test_valid_client_key(self):
"""
Verify that check_client_key succeeds with a valid key
"""
key = 'valid_key'
self.assertTrue(SignatureValidator().check_client_key(key))
key = self.lti_consumer.consumer_key
self.assertTrue(SignatureValidator(self.lti_consumer).check_client_key(key))
def test_long_client_key(self):
@ddt.data(
('0123456789012345678901234567890123456789',),
('',),
(None,),
)
@ddt.unpack
def test_invalid_client_key(self, key):
"""
Verify that check_client_key fails with a key that is too long
Verify that check_client_key fails with a disallowed key
"""
key = '0123456789012345678901234567890123456789'
self.assertFalse(SignatureValidator().check_client_key(key))
self.assertFalse(SignatureValidator(self.lti_consumer).check_client_key(key))
def test_empty_client_key(self):
"""
Verify that check_client_key fails with a key that is an empty string
"""
key = ''
self.assertFalse(SignatureValidator().check_client_key(key))
def test_null_client_key(self):
"""
Verify that check_client_key fails with a key that is None
"""
key = None
self.assertFalse(SignatureValidator().check_client_key(key))
@ddt.ddt
class NonceValidatorTest(TestCase):
"""
Tests for the check_nonce method in the SignatureValidator class.
"""
def setUp(self):
super(NonceValidatorTest, self).setUp()
self.lti_consumer = get_lti_consumer()
def test_valid_nonce(self):
"""
Verify that check_nonce succeeds with a key of maximum length
"""
nonce = '0123456789012345678901234567890123456789012345678901234567890123'
self.assertTrue(SignatureValidator().check_nonce(nonce))
self.assertTrue(SignatureValidator(self.lti_consumer).check_nonce(nonce))
def test_long_nonce(self):
@ddt.data(
('01234567890123456789012345678901234567890123456789012345678901234',),
('',),
(None,),
)
@ddt.unpack
def test_invalid_nonce(self, nonce):
"""
Verify that check_nonce fails with a key that is too long
Verify that check_nonce fails with badly formatted nonce
"""
nonce = '01234567890123456789012345678901234567890123456789012345678901234'
self.assertFalse(SignatureValidator().check_nonce(nonce))
self.assertFalse(SignatureValidator(self.lti_consumer).check_nonce(nonce))
def test_empty_nonce(self):
"""
Verify that check_nonce fails with a key that is an empty string
"""
nonce = ''
self.assertFalse(SignatureValidator().check_nonce(nonce))
def test_null_nonce(self):
"""
Verify that check_nonce fails with a key that is None
"""
nonce = None
self.assertFalse(SignatureValidator().check_nonce(nonce))
def test_validate_existing_key(self):
"""
Verify that validate_client_key succeeds if the client key exists in the
database
"""
LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
self.assertTrue(SignatureValidator().validate_client_key('client_key', None))
def test_validate_missing_key(self):
"""
Verify that validate_client_key fails if the client key is not in the
database
"""
self.assertFalse(SignatureValidator().validate_client_key('client_key', None))
class SignatureValidatorTest(TestCase):
"""
Tests for the custom SignatureValidator class that uses the oauthlib library
to check message signatures. Note that these tests mock out the library
itself, since we assume it to be correct.
"""
def setUp(self):
super(SignatureValidatorTest, self).setUp()
self.lti_consumer = get_lti_consumer()
def test_get_existing_client_secret(self):
"""
Verify that get_client_secret returns the right value if the key is in
the database
"""
LtiConsumer.objects.create(consumer_key='client_key', consumer_secret='client_secret')
secret = SignatureValidator().get_client_secret('client_key', None)
self.assertEqual(secret, 'client_secret')
def test_get_missing_client_secret(self):
"""
Verify that get_client_secret returns None if the key is not in the
database
Verify that get_client_secret returns the right value for the correct
key
"""
secret = SignatureValidator().get_client_secret('client_key', None)
self.assertIsNone(secret)
key = self.lti_consumer.consumer_key
secret = SignatureValidator(self.lti_consumer).get_client_secret(key, None)
self.assertEqual(secret, self.lti_consumer.consumer_secret)
@patch('oauthlib.oauth1.SignatureOnlyEndpoint.validate_request',
return_value=(True, None))
......@@ -116,6 +112,6 @@ class SignatureValidatorTest(TestCase):
content_type = 'application/x-www-form-urlencoded'
request = RequestFactory().post('/url', body, content_type=content_type)
headers = {'Content-Type': content_type}
SignatureValidator().verify(request)
SignatureValidator(self.lti_consumer).verify(request)
verify_mock.assert_called_once_with(
request.build_absolute_uri(), 'POST', body, headers)
......@@ -57,17 +57,6 @@ def build_launch_request(authenticated=True):
return request
def build_run_request(authenticated=True):
"""
Helper method to create a new request object
"""
request = RequestFactory().get('/')
request.user = UserFactory.create()
request.user.is_authenticated = MagicMock(return_value=authenticated)
request.session = {views.LTI_SESSION_KEY: ALL_PARAMS.copy()}
return request
class LtiTestMixin(object):
"""
Mixin for LTI tests
......@@ -144,49 +133,6 @@ class LtiLaunchTest(LtiTestMixin, TestCase):
response = views.lti_launch(request, None, None)
self.assertEqual(response.status_code, 403)
@patch('lti_provider.views.lti_run')
@patch('lti_provider.views.authenticate_lti_user')
def test_session_contents_after_launch(self, _authenticate, _run):
"""
Verifies that the LTI parameters and the course and usage IDs are
properly stored in the session
"""
request = build_launch_request()
views.lti_launch(request, unicode(COURSE_KEY), unicode(USAGE_KEY))
session = request.session[views.LTI_SESSION_KEY]
self.assertEqual(session['course_key'], COURSE_KEY, 'Course key not set in the session')
self.assertEqual(session['usage_key'], USAGE_KEY, 'Usage key not set in the session')
for key in views.REQUIRED_PARAMETERS:
self.assertEqual(session[key], request.POST[key], key + ' not set in the session')
@patch('lti_provider.views.lti_run')
@patch('lti_provider.views.authenticate_lti_user')
def test_optional_parameters_in_session(self, _authenticate, _run):
"""
Verifies that the outcome-related optional LTI parameters are properly
stored in the session
"""
request = build_launch_request()
request.POST.update(LTI_OPTIONAL_PARAMS)
views.lti_launch(
request,
unicode(COURSE_PARAMS['course_key']),
unicode(COURSE_PARAMS['usage_key'])
)
session = request.session[views.LTI_SESSION_KEY]
self.assertEqual(
session['lis_result_sourcedid'], u'result sourcedid',
'Result sourcedid not set in the session'
)
self.assertEqual(
session['lis_outcome_service_url'], u'outcome service URL',
'Outcome service URL not set in the session'
)
self.assertEqual(
session['tool_consumer_instance_guid'], u'consumer instance guid',
'Consumer instance GUID not set in the session'
)
def test_forbidden_if_signature_fails(self):
"""
Verifies that the view returns Forbidden if the LTI OAuth signature is
......@@ -198,71 +144,22 @@ class LtiLaunchTest(LtiTestMixin, TestCase):
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 403)
class LtiRunTest(LtiTestMixin, TestCase):
"""
Tests for the lti_run view
"""
@patch('lti_provider.views.render_courseware')
def test_valid_launch(self, render):
"""
Verifies that the view returns OK if called with the correct context
"""
request = build_run_request()
views.lti_run(request)
render.assert_called_with(request, ALL_PARAMS['usage_key'])
def test_forbidden_if_session_key_missing(self):
"""
Verifies that the lti_run view returns a Forbidden status if the session
doesn't have an entry for the LTI parameters.
"""
request = build_run_request()
del request.session[views.LTI_SESSION_KEY]
response = views.lti_run(request)
self.assertEqual(response.status_code, 403)
def test_forbidden_if_session_incomplete(self):
"""
Verifies that the lti_run view returns a Forbidden status if the session
is missing any of the required LTI parameters or course information.
"""
extra_keys = ['course_key', 'usage_key']
for key in views.REQUIRED_PARAMETERS + extra_keys:
request = build_run_request()
del request.session[views.LTI_SESSION_KEY][key]
response = views.lti_run(request)
self.assertEqual(
response.status_code,
403,
'Expected Forbidden response when session is missing ' + key
)
@patch('lti_provider.views.render_courseware')
def test_session_cleared_in_view(self, _render):
"""
Verifies that the LTI parameters are cleaned out of the session after
launching the view to prevent a launch being replayed.
"""
request = build_run_request()
views.lti_run(request)
self.assertNotIn(views.LTI_SESSION_KEY, request.session)
@patch('lti_provider.views.render_courseware')
def test_lti_consumer_record_supplemented_with_guid(self, _render):
request = build_run_request()
request.session[views.LTI_SESSION_KEY]['tool_consumer_instance_guid'] = 'instance_guid'
SignatureValidator.verify = MagicMock(return_value=False)
request = build_launch_request()
request.POST.update(LTI_OPTIONAL_PARAMS)
with self.assertNumQueries(4):
views.lti_run(request)
views.lti_launch(request, None, None)
consumer = models.LtiConsumer.objects.get(
consumer_key=LTI_DEFAULT_PARAMS['oauth_consumer_key']
)
self.assertEqual(consumer.instance_guid, 'instance_guid')
self.assertEqual(consumer.instance_guid, u'consumer instance guid')
class LtiRunTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase):
class LtiLaunchTestRender(LtiTestMixin, RenderXBlockTestMixin, ModuleStoreTestCase):
"""
Tests for the rendering returned by lti_run view.
Tests for the rendering returned by lti_launch view.
This class overrides the get_response method, which is used by
the tests defined in RenderXBlockTestMixin.
"""
......
......@@ -14,5 +14,4 @@ urlpatterns = patterns(
usage_id=settings.USAGE_ID_PATTERN
),
'lti_provider.views.lti_launch', name="lti_provider_launch"),
url(r'^lti_run$', 'lti_provider.views.lti_run', name="lti_provider_run"),
)
......@@ -33,8 +33,6 @@ OPTIONAL_PARAMETERS = [
'tool_consumer_instance_guid'
]
LTI_SESSION_KEY = 'lti_provider_parameters'
@csrf_exempt
def lti_launch(request, course_id, usage_id):
......@@ -48,38 +46,32 @@ def lti_launch(request, course_id, usage_id):
- The launch contains all the required parameters
- The launch data is correctly signed using a known client key/secret
pair
- The user is logged into the edX instance
Authentication in this view is a little tricky, since clients use a POST
with parameters to fetch it. We can't just use @login_required since in the
case where a user is not logged in it will redirect back after login using a
GET request, which would lose all of our LTI parameters.
Instead, we verify the LTI launch in this view before checking if the user
is logged in, and store the required LTI parameters in the session. Then we
do the authentication check, and if login is required we redirect back to
the lti_run view. If the user is already logged in, we just call that view
directly.
"""
if not settings.FEATURES['ENABLE_LTI_PROVIDER']:
return HttpResponseForbidden()
# Check the OAuth signature on the message
try:
if not SignatureValidator().verify(request):
return HttpResponseForbidden()
except LtiConsumer.DoesNotExist:
return HttpResponseForbidden()
# Check the LTI parameters, and return 400 if any required parameters are
# missing
params = get_required_parameters(request.POST)
if not params:
return HttpResponseBadRequest()
params.update(get_optional_parameters(request.POST))
# Store the course, and usage ID in the session to prevent privilege
# escalation if a staff member in one course tries to access material in
# another.
# Get the consumer information from either the instance GUID or the consumer
# key
try:
lti_consumer = LtiConsumer.get_or_supplement(
params.get('tool_consumer_instance_guid', None),
params['oauth_consumer_key']
)
except LtiConsumer.DoesNotExist:
return HttpResponseForbidden()
# Check the OAuth signature on the message
if not SignatureValidator(lti_consumer).verify(request):
return HttpResponseForbidden()
# Add the course and usage keys to the parameters array
try:
course_key, usage_key = parse_course_and_usage_keys(course_id, usage_id)
except InvalidKeyError:
......@@ -93,57 +85,13 @@ def lti_launch(request, course_id, usage_id):
params['course_key'] = course_key
params['usage_key'] = usage_key
try:
lti_consumer = LtiConsumer.get_or_supplement(
params.get('tool_consumer_instance_guid', None),
params['oauth_consumer_key']
)
except LtiConsumer.DoesNotExist:
return HttpResponseForbidden()
# Create an edX account if the user identifed by the LTI launch doesn't have
# one already, and log the edX account into the platform.
authenticate_lti_user(request, params['user_id'], lti_consumer)
request.session[LTI_SESSION_KEY] = params
return lti_run(request)
@login_required
def lti_run(request):
"""
This method can be reached in two ways, and must always follow a POST to
lti_launch:
- The user was logged in, so this method was called by lti_launch
- The user was not logged in, so the login process redirected them back here.
In either case, the session was populated by lti_launch, so all the required
LTI parameters will be stored there. Note that the request passed here may
or may not contain the LTI parameters (depending on how the user got here),
and so we should only use LTI parameters from the session.
Users should never call this view directly; if a user attempts to call it
without having first gone through lti_launch (and had the LTI parameters
stored in the session) they will get a 403 response.
"""
# Check the parameters to make sure that the session is associated with a
# valid LTI launch
params = restore_params_from_session(request)
if not params:
# This view has been called without first setting the session
return HttpResponseForbidden()
# Remove the parameters from the session to prevent replay
del request.session[LTI_SESSION_KEY]
# Store any parameters required by the outcome service in order to report
# scores back later. We know that the consumer exists, since the record was
# used earlier to verify the oauth signature.
lti_consumer = LtiConsumer.get_or_supplement(
params.get('tool_consumer_instance_guid', None),
params['oauth_consumer_key']
)
store_outcome_parameters(params, request.user, lti_consumer)
return render_courseware(request, params['usage_key'])
......@@ -184,26 +132,6 @@ def get_optional_parameters(dictionary):
return {key: dictionary[key] for key in OPTIONAL_PARAMETERS if key in dictionary}
def restore_params_from_session(request):
"""
Fetch the parameters that were stored in the session by an LTI launch, and
verify that all required parameters are present. Missing parameters could
indicate that a user has directly called the lti_run endpoint, rather than
going through the LTI launch.
:return: A dictionary of all LTI parameters from the session, or None if
any parameters are missing.
"""
if LTI_SESSION_KEY not in request.session:
return None
session_params = request.session[LTI_SESSION_KEY]
additional_params = ['course_key', 'usage_key']
for key in REQUIRED_PARAMETERS + additional_params:
if key not in session_params:
return None
return session_params
def render_courseware(request, usage_key):
"""
Render the content requested for the LTI launch.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment