Commit f0ab07a1 by Michael LoTurco

Upgrade audit enrollments on entitlement purchase

Adds check for existing upgradeable enrollment on user entitlement
creation, if single upgradeable enrollment is present, upgrades enrollment
to entitlement mode and links the entitlement to the upgraded enrollment

LEARNER-3579
parent 2c4a5207
...@@ -7,6 +7,9 @@ from datetime import datetime, timedelta ...@@ -7,6 +7,9 @@ from datetime import datetime, timedelta
import pytz import pytz
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from course_modes.models import CourseMode
from course_modes.tests.factories import CourseModeFactory
from mock import patch from mock import patch
from opaque_keys.edx.locator import CourseKey from opaque_keys.edx.locator import CourseKey
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
...@@ -34,6 +37,13 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -34,6 +37,13 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
self.user = UserFactory(is_staff=True) self.user = UserFactory(is_staff=True)
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
self.course = CourseFactory() self.course = CourseFactory()
self.course_mode = CourseModeFactory(
course_id=self.course.id,
mode_slug=CourseMode.VERIFIED,
# This must be in the future to ensure it is returned by downstream code.
expiration_datetime=datetime.now(pytz.UTC) + timedelta(days=1)
)
self.entitlements_list_url = reverse('entitlements_api:v1:entitlements-list') self.entitlements_list_url = reverse('entitlements_api:v1:entitlements-list')
def _get_data_set(self, user, course_uuid): def _get_data_set(self, user, course_uuid):
...@@ -115,6 +125,37 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ...@@ -115,6 +125,37 @@ class EntitlementViewSetTest(ModuleStoreTestCase):
) )
assert results == CourseEntitlementSerializer(course_entitlement).data assert results == CourseEntitlementSerializer(course_entitlement).data
@patch("entitlements.api.v1.views.get_course_runs_for_course")
def test_add_entitlement_and_upgrade_audit_enrollment(self, mock_get_course_runs):
"""
Verify that if an entitlement is added for a user, if the user has one upgradeable enrollment
that enrollment is upgraded to the mode of the entitlement and linked to the entitlement.
"""
course_uuid = uuid.uuid4()
entitlement_data = self._get_data_set(self.user, str(course_uuid))
mock_get_course_runs.return_value = [{'key': str(self.course.id)}]
# Add an audit course enrollment for user.
enrollment = CourseEnrollment.enroll(self.user, self.course.id, mode=CourseMode.AUDIT)
response = self.client.post(
self.entitlements_list_url,
data=json.dumps(entitlement_data),
content_type='application/json',
)
assert response.status_code == 201
results = response.data
course_entitlement = CourseEntitlement.objects.get(
user=self.user,
course_uuid=course_uuid
)
# Assert that enrollment mode is now verified
enrollment_mode = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id)[0]
assert enrollment_mode == course_entitlement.mode
assert course_entitlement.enrollment_course_run == enrollment
assert results == CourseEntitlementSerializer(course_entitlement).data
def test_non_staff_get_select_entitlements(self): def test_non_staff_get_select_entitlements(self):
not_staff_user = UserFactory() not_staff_user = UserFactory()
self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) self.client.login(username=not_staff_user.username, password=TEST_PASSWORD)
......
...@@ -53,6 +53,48 @@ class EntitlementViewSet(viewsets.ModelViewSet): ...@@ -53,6 +53,48 @@ class EntitlementViewSet(viewsets.ModelViewSet):
# to Admin users # to Admin users
return CourseEntitlement.objects.all().select_related('user').select_related('enrollment_course_run') return CourseEntitlement.objects.all().select_related('user').select_related('enrollment_course_run')
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
entitlement = serializer.instance
user = entitlement.user
# find all course_runs within the course
course_runs = get_course_runs_for_course(entitlement.course_uuid)
# check if the user has enrollments for any of the course_runs
user_run_enrollments = [
CourseEnrollment.get_enrollment(user, CourseKey.from_string(course_run.get('key')))
for course_run
in course_runs
if CourseEnrollment.get_enrollment(user, CourseKey.from_string(course_run.get('key')))
]
# filter to just enrollments that can be upgraded.
upgradeable_enrollments = [
enrollment
for enrollment
in user_run_enrollments
if enrollment.upgrade_deadline and enrollment.upgrade_deadline > timezone.now()
]
# if there is only one upgradeable enrollment, convert it from audit to the entitlement.mode
# if there is any ambiguity about which enrollment to upgrade
# (i.e. multiple upgradeable enrollments or no available upgradeable enrollment), dont enroll
if len(upgradeable_enrollments) == 1:
enrollment = upgradeable_enrollments[0]
log.info('Upgrading enrollment [%s] from audit to [%s] while adding entitlement for user [%s] for course [%s] ', enrollment, serializer.data.get('mode'), user.username, serializer.data.get('course_uuid'))
enrollment.update_enrollment(mode=entitlement.mode)
entitlement.set_enrollment(enrollment)
else:
log.info('No enrollment upgraded while adding entitlement for user [%s] for course [%s] ', user.username, serializer.data.get('course_uuid'))
headers = self.get_success_headers(serializer.data)
# Note, the entitlement is re-serialized before getting added to the Response, so that the 'modified' date reflects changes that occur when upgrading enrollment.
return Response(CourseEntitlementSerializer(entitlement).data, status=status.HTTP_201_CREATED, headers=headers)
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
""" """
Override the retrieve method to expire a record that is past the Override the retrieve method to expire a record that is past the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment