Commit 1e4f1b58 by Diana Huang

Allow the enrollment API to be accessed via API keys.

XCOM-107
parent 28a3e9f5
...@@ -16,11 +16,13 @@ from util.testing import UrlResetMixin ...@@ -16,11 +16,13 @@ from util.testing import UrlResetMixin
from enrollment import api from enrollment import api
from enrollment.errors import CourseEnrollmentError from enrollment.errors import CourseEnrollmentError
from openedx.core.djangoapps.user_api.models import UserOrgTag from openedx.core.djangoapps.user_api.models import UserOrgTag
from django.test.utils import override_settings
from student.tests.factories import UserFactory, CourseModeFactory from student.tests.factories import UserFactory, CourseModeFactory
from student.models import CourseEnrollment from student.models import CourseEnrollment
from embargo.test_utils import restrict_course from embargo.test_utils import restrict_course
@override_settings(EDX_API_KEY="i am a key")
@ddt.ddt @ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EnrollmentTest(ModuleStoreTestCase, APITestCase): class EnrollmentTest(ModuleStoreTestCase, APITestCase):
...@@ -30,12 +32,14 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -30,12 +32,14 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
USERNAME = "Bob" USERNAME = "Bob"
EMAIL = "bob@example.com" EMAIL = "bob@example.com"
PASSWORD = "edx" PASSWORD = "edx"
API_KEY = "i am a key"
def setUp(self): def setUp(self):
""" Create a course and user, then log in. """ """ Create a course and user, then log in. """
super(EnrollmentTest, self).setUp() super(EnrollmentTest, self).setUp()
self.course = CourseFactory.create() self.course = CourseFactory.create()
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.other_user = UserFactory.create()
self.client.login(username=self.USERNAME, password=self.PASSWORD) self.client.login(username=self.USERNAME, password=self.PASSWORD)
@ddt.data( @ddt.data(
...@@ -179,7 +183,10 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -179,7 +183,10 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
mode_slug='honor', mode_slug='honor',
mode_display_name='Honor', mode_display_name='Honor',
) )
self._create_enrollment(username='not_the_user', expected_status=status.HTTP_404_NOT_FOUND) self._create_enrollment(username=self.other_user.username, expected_status=status.HTTP_404_NOT_FOUND)
# Verify that the server still has access to this endpoint.
self.client.logout()
self._create_enrollment(username=self.other_user.username, as_server=True)
def test_user_does_not_match_param_for_list(self): def test_user_does_not_match_param_for_list(self):
CourseModeFactory.create( CourseModeFactory.create(
...@@ -187,8 +194,14 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -187,8 +194,14 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
mode_slug='honor', mode_slug='honor',
mode_display_name='Honor', mode_display_name='Honor',
) )
resp = self.client.get(reverse('courseenrollments'), {"user": "not_the_user"}) resp = self.client.get(reverse('courseenrollments'), {"user": self.other_user.username})
self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
# Verify that the server still has access to this endpoint.
self.client.logout()
resp = self.client.get(
reverse('courseenrollments'), {"user": self.other_user.username}, **{'HTTP_X_EDX_API_KEY': self.API_KEY}
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
def test_user_does_not_match_param(self): def test_user_does_not_match_param(self):
CourseModeFactory.create( CourseModeFactory.create(
...@@ -197,9 +210,16 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -197,9 +210,16 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
mode_display_name='Honor', mode_display_name='Honor',
) )
resp = self.client.get( resp = self.client.get(
reverse('courseenrollment', kwargs={"user": "not_the_user", "course_id": unicode(self.course.id)}) reverse('courseenrollment', kwargs={"user": self.other_user.username, "course_id": unicode(self.course.id)})
) )
# Verify that the server still has access to this endpoint.
self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
self.client.logout()
resp = self.client.get(
reverse('courseenrollment', kwargs={"user": self.other_user.username, "course_id": unicode(self.course.id)}),
**{'HTTP_X_EDX_API_KEY': self.API_KEY}
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
def test_get_course_details(self): def test_get_course_details(self):
CourseModeFactory.create( CourseModeFactory.create(
...@@ -237,7 +257,26 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -237,7 +257,26 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
) )
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
def _create_enrollment(self, course_id=None, username=None, expected_status=status.HTTP_200_OK, email_opt_in=None): def test_enrollment_already_enrolled(self):
response = self._create_enrollment()
repeat_response = self._create_enrollment()
self.assertEqual(json.loads(response.content), json.loads(repeat_response.content))
def test_get_enrollment_with_invalid_key(self):
resp = self.client.post(
reverse('courseenrollments'),
{
'course_details': {
'course_id': 'invalidcourse'
},
'user': self.user.username
},
format='json'
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("No course ", resp.content)
def _create_enrollment(self, course_id=None, username=None, expected_status=status.HTTP_200_OK, email_opt_in=None, as_server=False):
"""Enroll in the course and verify the URL we are sent to. """ """Enroll in the course and verify the URL we are sent to. """
course_id = unicode(self.course.id) if course_id is None else course_id course_id = unicode(self.course.id) if course_id is None else course_id
username = self.user.username if username is None else username username = self.user.username if username is None else username
...@@ -250,7 +289,11 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -250,7 +289,11 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
} }
if email_opt_in is not None: if email_opt_in is not None:
params['email_opt_in'] = email_opt_in params['email_opt_in'] = email_opt_in
if as_server:
resp = self.client.post(reverse('courseenrollments'), params, format='json', **{'HTTP_X_EDX_API_KEY': self.API_KEY})
else:
resp = self.client.post(reverse('courseenrollments'), params, format='json') resp = self.client.post(reverse('courseenrollments'), params, format='json')
self.assertEqual(resp.status_code, expected_status) self.assertEqual(resp.status_code, expected_status)
if expected_status == status.HTTP_200_OK: if expected_status == status.HTTP_200_OK:
...@@ -260,25 +303,6 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase): ...@@ -260,25 +303,6 @@ class EnrollmentTest(ModuleStoreTestCase, APITestCase):
self.assertTrue(data['is_active']) self.assertTrue(data['is_active'])
return resp return resp
def test_enrollment_already_enrolled(self):
response = self._create_enrollment()
repeat_response = self._create_enrollment()
self.assertEqual(json.loads(response.content), json.loads(repeat_response.content))
def test_get_enrollment_with_invalid_key(self):
resp = self.client.post(
reverse('courseenrollments'),
{
'course_details': {
'course_id': 'invalidcourse'
},
'user': self.user.username
},
format='json'
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("No course ", resp.content)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase): class EnrollmentEmbargoTest(UrlResetMixin, ModuleStoreTestCase):
......
...@@ -8,6 +8,7 @@ from django.conf import settings ...@@ -8,6 +8,7 @@ from django.conf import settings
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from openedx.core.djangoapps.user_api import api as user_api from openedx.core.djangoapps.user_api import api as user_api
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission, ApiKeyHeaderPermissionIsAuthenticated
from rest_framework import status from rest_framework import status
from rest_framework import permissions from rest_framework import permissions
from rest_framework.response import Response from rest_framework.response import Response
...@@ -30,8 +31,27 @@ class EnrollmentUserThrottle(UserRateThrottle): ...@@ -30,8 +31,27 @@ class EnrollmentUserThrottle(UserRateThrottle):
rate = '50/second' rate = '50/second'
class ApiKeyPermissionMixIn(object):
"""
This mixin is used to provide a convenience function for doing individual permission checks
for the presence of API keys.
"""
def has_api_key_permissions(self, request):
"""
Checks to see if the request was made by a server with an API key.
Args:
request (Request): the request being made into the view
Return:
True if the request has been made with a valid API key
False otherwise
"""
return ApiKeyHeaderPermission().has_permission(request, self)
@can_disable_rate_limit @can_disable_rate_limit
class EnrollmentView(APIView): class EnrollmentView(APIView, ApiKeyPermissionMixIn):
""" """
**Use Cases** **Use Cases**
...@@ -73,7 +93,7 @@ class EnrollmentView(APIView): ...@@ -73,7 +93,7 @@ class EnrollmentView(APIView):
""" """
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
permission_classes = permissions.IsAuthenticated, permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
throttle_classes = EnrollmentUserThrottle, throttle_classes = EnrollmentUserThrottle,
def get(self, request, course_id=None, user=None): def get(self, request, course_id=None, user=None):
...@@ -94,7 +114,7 @@ class EnrollmentView(APIView): ...@@ -94,7 +114,7 @@ class EnrollmentView(APIView):
""" """
user = user if user else request.user.username user = user if user else request.user.username
if request.user.username != user: if request.user.username != user and not self.has_api_key_permissions(request):
# Return a 404 instead of a 403 (Unauthorized). If one user is looking up # Return a 404 instead of a 403 (Unauthorized). If one user is looking up
# other users, do not let them deduce the existence of an enrollment. # other users, do not let them deduce the existence of an enrollment.
return Response(status=status.HTTP_404_NOT_FOUND) return Response(status=status.HTTP_404_NOT_FOUND)
...@@ -182,7 +202,7 @@ class EnrollmentCourseDetailView(APIView): ...@@ -182,7 +202,7 @@ class EnrollmentCourseDetailView(APIView):
@can_disable_rate_limit @can_disable_rate_limit
class EnrollmentListView(APIView): class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
""" """
**Use Cases** **Use Cases**
...@@ -245,7 +265,7 @@ class EnrollmentListView(APIView): ...@@ -245,7 +265,7 @@ class EnrollmentListView(APIView):
""" """
authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser authentication_classes = OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser
permission_classes = permissions.IsAuthenticated, permission_classes = ApiKeyHeaderPermissionIsAuthenticated,
throttle_classes = EnrollmentUserThrottle, throttle_classes = EnrollmentUserThrottle,
def get(self, request): def get(self, request):
...@@ -253,7 +273,7 @@ class EnrollmentListView(APIView): ...@@ -253,7 +273,7 @@ class EnrollmentListView(APIView):
Gets a list of all course enrollments for the currently logged in user. Gets a list of all course enrollments for the currently logged in user.
""" """
user = request.GET.get('user', request.user.username) user = request.GET.get('user', request.user.username)
if request.user.username != user: if request.user.username != user and not self.has_api_key_permissions(request):
# Return a 404 instead of a 403 (Unauthorized). If one user is looking up # Return a 404 instead of a 403 (Unauthorized). If one user is looking up
# other users, do not let them deduce the existence of an enrollment. # other users, do not let them deduce the existence of an enrollment.
return Response(status=status.HTTP_404_NOT_FOUND) return Response(status=status.HTTP_404_NOT_FOUND)
...@@ -276,7 +296,7 @@ class EnrollmentListView(APIView): ...@@ -276,7 +296,7 @@ class EnrollmentListView(APIView):
user = request.DATA.get('user', request.user.username) user = request.DATA.get('user', request.user.username)
if not user: if not user:
user = request.user.username user = request.user.username
if user != request.user.username: if user != request.user.username and not self.has_api_key_permissions(request):
# Return a 404 instead of a 403 (Unauthorized). If one user is looking up # Return a 404 instead of a 403 (Unauthorized). If one user is looking up
# other users, do not let them deduce the existence of an enrollment. # other users, do not let them deduce the existence of an enrollment.
return Response(status=status.HTTP_404_NOT_FOUND) return Response(status=status.HTTP_404_NOT_FOUND)
......
...@@ -21,6 +21,19 @@ class ApiKeyHeaderPermission(permissions.BasePermission): ...@@ -21,6 +21,19 @@ class ApiKeyHeaderPermission(permissions.BasePermission):
) )
class ApiKeyHeaderPermissionIsAuthenticated(ApiKeyHeaderPermission, permissions.IsAuthenticated):
"""
Allow someone to access the view if they have the API key OR they are authenticated.
See ApiKeyHeaderPermission for more information how the API key portion is implemented.
"""
def has_permission(self, request, view):
#TODO We can optimize this later on when we know which of these methods is used more often.
api_permissions = ApiKeyHeaderPermission.has_permission(self, request, view)
is_authenticated_permissions = permissions.IsAuthenticated.has_permission(self, request, view)
return api_permissions or is_authenticated_permissions
class IsAuthenticatedOrDebug(permissions.BasePermission): class IsAuthenticatedOrDebug(permissions.BasePermission):
""" """
Allows access only to authenticated users, or anyone if debug mode is enabled. Allows access only to authenticated users, or anyone if debug mode is enabled.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment