Commit 0ea3753d by Albert St. Aubin

Added Entitlement enroll and unenroll logic to the Enrollment API

[LEARNER-3136]

This commit containts the logic and API endpoint for a user to Enroll,
Unenroll, and switch-session on a Course Entitlement.
parent fef38170
import json import json
import logging
import unittest import unittest
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
...@@ -6,11 +7,16 @@ from datetime import datetime, timedelta ...@@ -6,11 +7,16 @@ from datetime import datetime, timedelta
import pytz import pytz
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from mock import patch
from student.tests.factories import (TEST_PASSWORD, CourseEnrollmentFactory, UserFactory) from opaque_keys.edx.locator import CourseKey
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from student.models import CourseEnrollment
from student.tests.factories import (TEST_PASSWORD, CourseEnrollmentFactory, UserFactory)
log = logging.getLogger(__name__)
# Entitlements is not in CMS' INSTALLED_APPS so these imports will error during test collection # Entitlements is not in CMS' INSTALLED_APPS so these imports will error during test collection
if settings.ROOT_URLCONF == 'lms.urls': if settings.ROOT_URLCONF == 'lms.urls':
from entitlements.tests.factories import CourseEntitlementFactory from entitlements.tests.factories import CourseEntitlementFactory
...@@ -81,7 +87,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -81,7 +87,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
not_staff_user = UserFactory() not_staff_user = UserFactory()
self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) self.client.login(username=not_staff_user.username, password=TEST_PASSWORD)
course_entitlement = CourseEntitlementFactory() course_entitlement = CourseEntitlementFactory.create()
url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(course_entitlement.uuid)]) url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(course_entitlement.uuid)])
response = self.client.delete( response = self.client.delete(
...@@ -122,7 +128,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -122,7 +128,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
results = response.data.get('results', []) # pylint: disable=no-member results = response.data.get('results', []) # pylint: disable=no-member
assert results == CourseEntitlementSerializer([entitlement], many=True).data assert results == CourseEntitlementSerializer([entitlement], many=True).data
def test_staff_not_get_all_entitlements(self): def test_staff_get_only_staff_entitlements(self):
CourseEntitlementFactory.create_batch(2) CourseEntitlementFactory.create_batch(2)
entitlement = CourseEntitlementFactory.create(user=self.user) entitlement = CourseEntitlementFactory.create(user=self.user)
...@@ -189,7 +195,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -189,7 +195,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
assert results == CourseEntitlementSerializer([entitlement_user2], many=True).data assert results == CourseEntitlementSerializer([entitlement_user2], many=True).data
def test_get_entitlement_by_uuid(self): def test_get_entitlement_by_uuid(self):
entitlement = CourseEntitlementFactory() entitlement = CourseEntitlementFactory.create()
CourseEntitlementFactory.create_batch(2) CourseEntitlementFactory.create_batch(2)
url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(entitlement.uuid)]) url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(entitlement.uuid)])
...@@ -253,3 +259,173 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -253,3 +259,173 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
course_entitlement.refresh_from_db() course_entitlement.refresh_from_db()
assert course_entitlement.expired_at is not None assert course_entitlement.expired_at is not None
assert course_entitlement.enrollment_course_run is None assert course_entitlement.enrollment_course_run is None
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EntitlementEnrollmentViewSetTest(ModuleStoreTestCase):
"""
Tests for the EntitlementEnrollmentViewSets
"""
ENTITLEMENTS_ENROLLMENT_NAMESPACE = 'entitlements_api:v1:enrollments'
def setUp(self):
super(EntitlementEnrollmentViewSetTest, self).setUp()
self.user = UserFactory()
self.client.login(username=self.user.username, password=TEST_PASSWORD)
self.course = CourseFactory.create(org='edX', number='DemoX', display_name='Demo_Course')
self.course2 = CourseFactory.create(org='edX', number='DemoX2', display_name='Demo_Course 2')
self.return_values = [
{'key': str(self.course.id)},
{'key': str(self.course2.id)}
]
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_user_can_enroll(self, mock_get_course_runs):
course_entitlement = CourseEntitlementFactory.create(user=self.user)
mock_get_course_runs.return_value = self.return_values
url = reverse(
self.ENTITLEMENTS_ENROLLMENT_NAMESPACE,
args=[str(course_entitlement.uuid)]
)
assert course_entitlement.enrollment_course_run is None
data = {
'course_run_id': str(self.course.id)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
course_entitlement.refresh_from_db()
assert response.status_code == 201
assert CourseEnrollment.is_enrolled(self.user, self.course.id)
assert course_entitlement.enrollment_course_run is not None
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_user_can_unenroll(self, mock_get_course_runs):
course_entitlement = CourseEntitlementFactory.create(user=self.user)
mock_get_course_runs.return_value = self.return_values
url = reverse(
self.ENTITLEMENTS_ENROLLMENT_NAMESPACE,
args=[str(course_entitlement.uuid)]
)
assert course_entitlement.enrollment_course_run is None
data = {
'course_run_id': str(self.course.id)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
course_entitlement.refresh_from_db()
assert response.status_code == 201
assert CourseEnrollment.is_enrolled(self.user, self.course.id)
response = self.client.delete(
url,
content_type='application/json',
)
assert response.status_code == 204
course_entitlement.refresh_from_db()
assert not CourseEnrollment.is_enrolled(self.user, self.course.id)
assert course_entitlement.enrollment_course_run is None
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_user_can_switch(self, mock_get_course_runs):
mock_get_course_runs.return_value = self.return_values
course_entitlement = CourseEntitlementFactory.create(user=self.user)
url = reverse(
self.ENTITLEMENTS_ENROLLMENT_NAMESPACE,
args=[str(course_entitlement.uuid)]
)
assert course_entitlement.enrollment_course_run is None
data = {
'course_run_id': str(self.course.id)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
course_entitlement.refresh_from_db()
assert response.status_code == 201
assert CourseEnrollment.is_enrolled(self.user, self.course.id)
data = {
'course_run_id': str(self.course2.id)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
assert response.status_code == 201
course_entitlement.refresh_from_db()
assert CourseEnrollment.is_enrolled(self.user, self.course2.id)
assert course_entitlement.enrollment_course_run is not None
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_user_already_enrolled(self, mock_get_course_runs):
course_entitlement = CourseEntitlementFactory.create(user=self.user)
mock_get_course_runs.return_value = self.return_values
url = reverse(
self.ENTITLEMENTS_ENROLLMENT_NAMESPACE,
args=[str(course_entitlement.uuid)]
)
CourseEnrollment.enroll(self.user, self.course.id, mode=course_entitlement.mode)
data = {
'course_run_id': str(self.course.id)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
course_entitlement.refresh_from_db()
assert response.status_code == 201
assert CourseEnrollment.is_enrolled(self.user, self.course.id)
course_entitlement.refresh_from_db()
assert CourseEnrollment.is_enrolled(self.user, self.course.id)
assert course_entitlement.enrollment_course_run is not None
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_user_cannot_enroll_in_unknown_course_run_id(self, mock_get_course_runs):
fake_course_str = str(self.course.id) + 'fake'
fake_course_key = CourseKey.from_string(fake_course_str)
course_entitlement = CourseEntitlementFactory.create(user=self.user)
mock_get_course_runs.return_value = self.return_values
url = reverse(
self.ENTITLEMENTS_ENROLLMENT_NAMESPACE,
args=[str(course_entitlement.uuid)]
)
data = {
'course_run_id': str(fake_course_key)
}
response = self.client.post(
url,
data=json.dumps(data),
content_type='application/json',
)
expected_message = 'The Course Run ID is not a match for this Course Entitlement.'
assert response.status_code == 400
assert response.data['message'] == expected_message # pylint: disable=no-member
assert not CourseEnrollment.is_enrolled(self.user, fake_course_key)
from django.conf.urls import url, include from django.conf.urls import url, include
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from .views import EntitlementViewSet from .views import EntitlementViewSet, EntitlementEnrollmentViewSet
router = DefaultRouter() router = DefaultRouter()
router.register(r'entitlements', EntitlementViewSet, base_name='entitlements') router.register(r'entitlements', EntitlementViewSet, base_name='entitlements')
ENROLLMENTS_VIEW = EntitlementEnrollmentViewSet.as_view({
'post': 'create',
'delete': 'destroy',
})
urlpatterns = [ urlpatterns = [
url(r'', include(router.urls)), url(r'', include(router.urls)),
url(
r'entitlements/(?P<uuid>{regex})/enrollments$'.format(regex=EntitlementViewSet.ENTITLEMENT_UUID4_REGEX),
ENROLLMENTS_VIEW,
name='enrollments'
)
] ]
...@@ -4,23 +4,30 @@ from django.db import transaction ...@@ -4,23 +4,30 @@ from django.db import transaction
from django.utils import timezone from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from edx_rest_framework_extensions.authentication import JwtAuthentication from edx_rest_framework_extensions.authentication import JwtAuthentication
from rest_framework import permissions, viewsets from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
from rest_framework import permissions, viewsets, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.response import Response from rest_framework.response import Response
from entitlements.api.v1.filters import CourseEntitlementFilter from entitlements.api.v1.filters import CourseEntitlementFilter
from entitlements.api.v1.permissions import IsAdminOrAuthenticatedReadOnly from entitlements.api.v1.permissions import IsAdminOrAuthenticatedReadOnly
from entitlements.api.v1.serializers import CourseEntitlementSerializer from entitlements.api.v1.serializers import CourseEntitlementSerializer
from entitlements.models import CourseEntitlement from entitlements.models import CourseEntitlement
from openedx.core.djangoapps.catalog.utils import get_course_runs_for_course
from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
from student.models import CourseEnrollment from student.models import CourseEnrollment
from student.models import CourseEnrollmentException, AlreadyEnrolledError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EntitlementViewSet(viewsets.ModelViewSet): class EntitlementViewSet(viewsets.ModelViewSet):
ENTITLEMENT_UUID4_REGEX = '[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}'
authentication_classes = (JwtAuthentication, SessionAuthenticationCrossDomainCsrf,) authentication_classes = (JwtAuthentication, SessionAuthenticationCrossDomainCsrf,)
permission_classes = (permissions.IsAuthenticated, IsAdminOrAuthenticatedReadOnly,) permission_classes = (permissions.IsAuthenticated, IsAdminOrAuthenticatedReadOnly,)
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = ENTITLEMENT_UUID4_REGEX
lookup_field = 'uuid' lookup_field = 'uuid'
serializer_class = CourseEntitlementSerializer serializer_class = CourseEntitlementSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
...@@ -102,3 +109,169 @@ class EntitlementViewSet(viewsets.ModelViewSet): ...@@ -102,3 +109,169 @@ class EntitlementViewSet(viewsets.ModelViewSet):
) )
if save_model: if save_model:
instance.save() instance.save()
class EntitlementEnrollmentViewSet(viewsets.GenericViewSet):
"""
Endpoint in the Entitlement API to handle the Enrollment of a User's Entitlement.
This API will handle
- Enroll
- Unenroll
- Switch Enrollment
"""
authentication_classes = (JwtAuthentication, SessionAuthentication,)
permission_classes = (permissions.IsAuthenticated,)
queryset = CourseEntitlement.objects.all()
def _verify_course_run_for_entitlement(self, entitlement, course_run_id):
"""
Verifies that a Course run is a child of the Course assigned to the entitlement.
"""
course_runs = get_course_runs_for_course(entitlement.course_uuid)
for run in course_runs:
if course_run_id == run.get('key', ''):
return True
return False
def _enroll_entitlement(self, entitlement, course_run_key, user):
"""
Internal method to handle the details of enrolling a User in a Course Run.
Returns a response object is there is an error or exception, None otherwise
"""
try:
enrollment = CourseEnrollment.enroll(
user=user,
course_key=course_run_key,
mode=entitlement.mode,
check_access=True
)
except AlreadyEnrolledError:
enrollment = CourseEnrollment.get_enrollment(user, course_run_key)
if enrollment.mode == entitlement.mode:
CourseEntitlement.set_enrollment(entitlement, enrollment)
# Else the User is already enrolled in another Mode and we should
# not do anything else related to Entitlements.
except CourseEnrollmentException:
message = (
'Course Entitlement Enroll for {username} failed for course: {course_id}, '
'mode: {mode}, and entitlement: {entitlement}'
).format(
username=user.username,
course_id=course_run_key,
mode=entitlement.mode,
entitlement=entitlement.uuid
)
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'message': message}
)
CourseEntitlement.set_enrollment(entitlement, enrollment)
return None
def _unenroll_entitlement(self, entitlement, course_run_key, user):
"""
Internal method to handle the details of Unenrolling a User in a Course Run.
"""
CourseEnrollment.unenroll(user, course_run_key, skip_refund=True)
CourseEntitlement.set_enrollment(entitlement, None)
def create(self, request, uuid):
"""
On POST this method will be called and will handle enrolling a user in the
provided course_run_id from the data. This is called on a specific entitlement
UUID so the course_run_id has to correspond to the Course that is assigned to
the Entitlement.
When this API is called for a user who is already enrolled in a run that User
will be unenrolled from their current run and enrolled in the new run if it is
available.
"""
course_run_id = request.data.get('course_run_id', None)
if not course_run_id:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data='The Course Run ID was not provided.'
)
# Verify that the user has an Entitlement for the provided Course UUID.
try:
entitlement = CourseEntitlement.objects.get(uuid=uuid, user=request.user, expired_at=None)
except CourseEntitlement.DoesNotExist:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data='The Entitlement for this UUID does not exist or is Expired.'
)
# Verify the course run ID is of the same type as the Course entitlement.
course_run_valid = self._verify_course_run_for_entitlement(entitlement, course_run_id)
if not course_run_valid:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={
'message': 'The Course Run ID is not a match for this Course Entitlement.'
}
)
# Determine if this is a Switch session or a simple enroll and handle both.
try:
course_run_string = CourseKey.from_string(course_run_id)
except InvalidKeyError:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={
'message': 'Invalid {course_id}'.format(course_id=course_run_id)
}
)
if entitlement.enrollment_course_run is None:
response = self._enroll_entitlement(
entitlement=entitlement,
course_run_key=course_run_string,
user=request.user
)
if response:
return response
elif entitlement.enrollment_course_run.course_id != course_run_id:
self._unenroll_entitlement(
entitlement=entitlement,
course_run_key=entitlement.enrollment_course_run.course_id,
user=request.user
)
response = self._enroll_entitlement(
entitlement=entitlement,
course_run_key=course_run_string,
user=request.user
)
if response:
return response
return Response(
status=status.HTTP_201_CREATED,
data={
'course_run_id': course_run_id,
}
)
def destroy(self, request, uuid):
"""
On DELETE call to this API we will unenroll the course enrollment for the provided uuid
"""
try:
entitlement = CourseEntitlement.objects.get(uuid=uuid, user=request.user, expired_at=None)
except CourseEntitlement.DoesNotExist:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data='The Entitlement for this UUID does not exist or is Expired.'
)
if entitlement.enrollment_course_run is None:
return Response(status=status.HTTP_204_NO_CONTENT)
self._unenroll_entitlement(
entitlement=entitlement,
course_run_key=entitlement.enrollment_course_run.course_id,
user=request.user
)
return Response(status=status.HTTP_204_NO_CONTENT)
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('entitlements', '0003_auto_20171205_1431'),
]
operations = [
migrations.AlterField(
model_name='courseentitlement',
name='uuid',
field=models.UUIDField(default=uuid.uuid4, unique=True, editable=False),
),
]
...@@ -125,7 +125,7 @@ class CourseEntitlement(TimeStampedModel): ...@@ -125,7 +125,7 @@ class CourseEntitlement(TimeStampedModel):
""" """
user = models.ForeignKey(settings.AUTH_USER_MODEL) user = models.ForeignKey(settings.AUTH_USER_MODEL)
uuid = models.UUIDField(default=uuid_tools.uuid4, editable=False) uuid = models.UUIDField(default=uuid_tools.uuid4, editable=False, unique=True)
course_uuid = models.UUIDField(help_text='UUID for the Course, not the Course Run') course_uuid = models.UUIDField(help_text='UUID for the Course, not the Course Run')
expired_at = models.DateTimeField( expired_at = models.DateTimeField(
null=True, null=True,
...@@ -212,3 +212,10 @@ class CourseEntitlement(TimeStampedModel): ...@@ -212,3 +212,10 @@ class CourseEntitlement(TimeStampedModel):
Returns a boolean as to whether or not the entitlement can be redeemed based on the entitlement's policy Returns a boolean as to whether or not the entitlement can be redeemed based on the entitlement's policy
""" """
return self.policy.is_entitlement_redeemable(self) return self.policy.is_entitlement_redeemable(self)
@classmethod
def set_enrollment(cls, entitlement, enrollment):
"""
Fulfills an entitlement by specifying a session.
"""
cls.objects.filter(id=entitlement.id).update(enrollment_course_run=enrollment)
...@@ -235,7 +235,6 @@ def get_course_runs_for_course(course_uuid): ...@@ -235,7 +235,6 @@ def get_course_runs_for_course(course_uuid):
cache_key=cache_key if catalog_integration.is_cache_enabled else None, cache_key=cache_key if catalog_integration.is_cache_enabled else None,
long_term_cache=True long_term_cache=True
) )
return data.get('course_runs', []) return data.get('course_runs', [])
else: else:
return [] return []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment