Commit 028cf8d4 by McKenzie Welter Committed by McKenzie Welter

Adding CourseEntitlement data to Course Model

parent 42de1b02
...@@ -19,10 +19,10 @@ from course_discovery.apps.api.fields import ImageField, StdImageSerializerField ...@@ -19,10 +19,10 @@ from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.api_client.lms import LMSAPIClient from course_discovery.apps.core.api_client.lms import LMSAPIClient
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import (FAQ, CorporateEndorsement, Course, CourseRun, Endorsement, from course_discovery.apps.course_metadata.models import (FAQ, CorporateEndorsement, Course, CourseEntitlement,
Image, Organization, Person, PersonSocialNetwork, PersonWork, CourseRun, Endorsement, Image, Organization, Person,
Position, Prerequisite, Program, ProgramType, Seat, Subject, PersonSocialNetwork, PersonWork, Position, Prerequisite,
Video) Program, ProgramType, Seat, SeatType, Subject, Video)
from course_discovery.apps.course_metadata.search_indexes import CourseIndex, CourseRunIndex, ProgramIndex from course_discovery.apps.course_metadata.search_indexes import CourseIndex, CourseRunIndex, ProgramIndex
User = get_user_model() User = get_user_model()
...@@ -390,6 +390,25 @@ class SeatSerializer(serializers.ModelSerializer): ...@@ -390,6 +390,25 @@ class SeatSerializer(serializers.ModelSerializer):
fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours', 'sku',) fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours', 'sku',)
class CourseEntitlementSerializer(serializers.ModelSerializer):
"""Serializer for the ``CourseEntitlement`` model."""
price = serializers.DecimalField(
decimal_places=CourseEntitlement.PRICE_FIELD_CONFIG['decimal_places'],
max_digits=CourseEntitlement.PRICE_FIELD_CONFIG['max_digits']
)
currency = serializers.SlugRelatedField(read_only=True, slug_field='code')
sku = serializers.CharField()
mode = serializers.SlugRelatedField(slug_field='name', queryset=SeatType.objects.all())
@classmethod
def prefetch_queryset(cls):
return CourseEntitlement.objects.all().select_related('currency', 'mode')
class Meta(object):
model = CourseEntitlement
fields = ('mode', 'price', 'currency', 'sku',)
class MinimalOrganizationSerializer(serializers.ModelSerializer): class MinimalOrganizationSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Organization model = Organization
...@@ -556,6 +575,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer): ...@@ -556,6 +575,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer):
class MinimalCourseSerializer(TimestampModelSerializer): class MinimalCourseSerializer(TimestampModelSerializer):
course_runs = MinimalCourseRunSerializer(many=True) course_runs = MinimalCourseRunSerializer(many=True)
entitlements = CourseEntitlementSerializer(many=True)
owners = MinimalOrganizationSerializer(many=True, source='authoring_organizations') owners = MinimalOrganizationSerializer(many=True, source='authoring_organizations')
image = ImageField(read_only=True, source='image_url') image = ImageField(read_only=True, source='image_url')
...@@ -567,12 +587,13 @@ class MinimalCourseSerializer(TimestampModelSerializer): ...@@ -567,12 +587,13 @@ class MinimalCourseSerializer(TimestampModelSerializer):
return queryset.select_related('partner').prefetch_related( return queryset.select_related('partner').prefetch_related(
'authoring_organizations', 'authoring_organizations',
'entitlements',
Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset(queryset=course_runs)), Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset(queryset=course_runs)),
) )
class Meta: class Meta:
model = Course model = Course
fields = ('key', 'uuid', 'title', 'course_runs', 'owners', 'image', 'short_description',) fields = ('key', 'uuid', 'title', 'course_runs', 'entitlements', 'owners', 'image', 'short_description',)
class CourseSerializer(MinimalCourseSerializer): class CourseSerializer(MinimalCourseSerializer):
...@@ -598,6 +619,7 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -598,6 +619,7 @@ class CourseSerializer(MinimalCourseSerializer):
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
'entitlements',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)), Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
......
...@@ -19,10 +19,11 @@ from waffle.testutils import override_switch ...@@ -19,10 +19,11 @@ from waffle.testutils import override_switch
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, CatalogSerializer, from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, CatalogSerializer,
ContainedCourseRunsSerializer, ContainedCoursesSerializer, ContainedCourseRunsSerializer, ContainedCoursesSerializer,
CorporateEndorsementSerializer, CourseRunSearchSerializer, CorporateEndorsementSerializer, CourseEntitlementSerializer,
CourseRunSerializer, CourseRunWithProgramsSerializer, CourseRunSearchSerializer, CourseRunSerializer,
CourseSearchSerializer, CourseSerializer, CourseRunWithProgramsSerializer, CourseSearchSerializer,
CourseWithProgramsSerializer, EndorsementSerializer, FAQSerializer, CourseSerializer, CourseWithProgramsSerializer,
EndorsementSerializer, FAQSerializer,
FlattenedCourseRunWithCourseSerializer, ImageSerializer, FlattenedCourseRunWithCourseSerializer, ImageSerializer,
MinimalCourseRunSerializer, MinimalCourseSerializer, MinimalCourseRunSerializer, MinimalCourseSerializer,
MinimalOrganizationSerializer, MinimalProgramCourseSerializer, MinimalOrganizationSerializer, MinimalProgramCourseSerializer,
...@@ -123,6 +124,7 @@ class MinimalCourseSerializerTests(SiteMixin, TestCase): ...@@ -123,6 +124,7 @@ class MinimalCourseSerializerTests(SiteMixin, TestCase):
'uuid': str(course.uuid), 'uuid': str(course.uuid),
'title': course.title, 'title': course.title,
'course_runs': MinimalCourseRunSerializer(course.course_runs, many=True, context=context).data, 'course_runs': MinimalCourseRunSerializer(course.course_runs, many=True, context=context).data,
'entitlements': [],
'owners': MinimalOrganizationSerializer(course.authoring_organizations, many=True, context=context).data, 'owners': MinimalOrganizationSerializer(course.authoring_organizations, many=True, context=context).data,
'image': ImageField().to_representation(course.image_url), 'image': ImageField().to_representation(course.image_url),
'short_description': course.short_description 'short_description': course.short_description
...@@ -161,6 +163,7 @@ class CourseSerializerTests(MinimalCourseSerializerTests): ...@@ -161,6 +163,7 @@ class CourseSerializerTests(MinimalCourseSerializerTests):
}) })
), ),
'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data, 'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data,
'entitlements': CourseEntitlementSerializer(many=True).data,
'owners': OrganizationSerializer(course.authoring_organizations, many=True).data, 'owners': OrganizationSerializer(course.authoring_organizations, many=True).data,
'prerequisites_raw': course.prerequisites_raw, 'prerequisites_raw': course.prerequisites_raw,
'syllabus_raw': course.syllabus_raw, 'syllabus_raw': course.syllabus_raw,
......
...@@ -179,7 +179,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -179,7 +179,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# to be included. # to be included.
filtered_course_run = CourseRunFactory(course=course) filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(20): with self.assertNumQueries(21):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
......
...@@ -32,7 +32,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -32,7 +32,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(21): with self.assertNumQueries(23):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -41,7 +41,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -41,7 +41,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course with UUID. """ """ Verify the endpoint returns the details for a single course with UUID. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.uuid}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.uuid})
with self.assertNumQueries(21): with self.assertNumQueries(23):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -50,7 +50,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -50,7 +50,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """ """ Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(14): with self.assertNumQueries(15):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), []) self.assertEqual(response.data.get('programs'), [])
...@@ -63,7 +63,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -63,7 +63,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1' url += '?include_deleted_programs=1'
with self.assertNumQueries(25): with self.assertNumQueries(27):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual( self.assertEqual(
...@@ -199,7 +199,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -199,7 +199,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(27): with self.assertNumQueries(29):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -215,7 +215,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -215,7 +215,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(41): with self.assertNumQueries(47):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -226,7 +226,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -226,7 +226,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(40): with self.assertNumQueries(46):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -237,7 +237,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -237,7 +237,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
uuids = ','.join([str(course.uuid) for course in courses]) uuids = ','.join([str(course.uuid) for course in courses])
url = '{root}?uuids={uuids}'.format(root=reverse('api:v1:course-list'), uuids=uuids) url = '{root}?uuids={uuids}'.format(root=reverse('api:v1:course-list'), uuids=uuids)
with self.assertNumQueries(40): with self.assertNumQueries(46):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
...@@ -88,7 +88,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -88,7 +88,7 @@ class TestProgramViewSet(SerializationMixin):
def test_retrieve(self, django_assert_num_queries): def test_retrieve(self, django_assert_num_queries):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = self.create_program() program = self.create_program()
with django_assert_num_queries(40): with django_assert_num_queries(41):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
...@@ -114,7 +114,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -114,7 +114,7 @@ class TestProgramViewSet(SerializationMixin):
partner=self.partner) partner=self.partner)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
with django_assert_num_queries(29): with django_assert_num_queries(30):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
assert course_list == list(program.courses.all()) # pylint: disable=no-member assert course_list == list(program.courses.all()) # pylint: disable=no-member
...@@ -123,7 +123,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -123,7 +123,7 @@ class TestProgramViewSet(SerializationMixin):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory(partner=self.partner) course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner) program = ProgramFactory(courses=[course], partner=self.partner)
with django_assert_num_queries(22): with django_assert_num_queries(23):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
...@@ -150,7 +150,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -150,7 +150,7 @@ class TestProgramViewSet(SerializationMixin):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)] expected = [self.create_program() for __ in range(3)]
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 15) self.assert_list_results(self.list_path, expected, 16)
# Verify that repeated list requests use the cache. # Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 4) self.assert_list_results(self.list_path, expected, 4)
...@@ -274,13 +274,13 @@ class TestProgramViewSet(SerializationMixin): ...@@ -274,13 +274,13 @@ class TestProgramViewSet(SerializationMixin):
program.marketing_slug = SLUG program.marketing_slug = SLUG
program.save() program.save()
self.assert_list_results(url, [program], 15) self.assert_list_results(url, [program], 16)
def test_list_exclude_utm(self): def test_list_exclude_utm(self):
""" Verify the endpoint returns marketing URLs without UTM parameters. """ """ Verify the endpoint returns marketing URLs without UTM parameters. """
url = self.list_path + '?exclude_utm=1' url = self.list_path + '?exclude_utm=1'
program = self.create_program() program = self.create_program()
self.assert_list_results(url, [program], 14, extra_context={'exclude_utm': 1}) self.assert_list_results(url, [program], 15, extra_context={'exclude_utm': 1})
def test_minimal_serializer_use(self): def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """ """ Verify that the list view uses the minimal serializer. """
......
...@@ -12,7 +12,7 @@ from course_discovery.apps.core.models import Currency ...@@ -12,7 +12,7 @@ from course_discovery.apps.core.models import Currency
from course_discovery.apps.course_metadata.choices import CourseRunPacing, CourseRunStatus from course_discovery.apps.course_metadata.choices import CourseRunPacing, CourseRunStatus
from course_discovery.apps.course_metadata.data_loaders import AbstractDataLoader from course_discovery.apps.course_metadata.data_loaders import AbstractDataLoader
from course_discovery.apps.course_metadata.models import ( from course_discovery.apps.course_metadata.models import (
Course, CourseRun, Organization, Program, ProgramType, Seat, Video Course, CourseEntitlement, CourseRun, Organization, Program, ProgramType, Seat, SeatType, Video
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -248,16 +248,18 @@ class CoursesApiDataLoader(AbstractDataLoader): ...@@ -248,16 +248,18 @@ class CoursesApiDataLoader(AbstractDataLoader):
class EcommerceApiDataLoader(AbstractDataLoader): class EcommerceApiDataLoader(AbstractDataLoader):
""" Loads course seats from the E-Commerce API. """ """ Loads course seats and entitlements from the E-Commerce API. """
def ingest(self): def ingest(self):
logger.info('Refreshing course seats from %s...', self.partner.ecommerce_api_url) logger.info('Refreshing course seats from %s...', self.partner.ecommerce_api_url)
initial_page = 1 initial_page = 1
response = self._make_request(initial_page) course_runs = self._request_course_runs(initial_page)
count = response['count'] entitlements = self._request_entitlments(initial_page)
count = course_runs['count'] + entitlements['count']
pages = math.ceil(count / self.PAGE_SIZE) pages = math.ceil(count / self.PAGE_SIZE)
self._process_response(response) self._process_course_runs(course_runs)
self._process_entitlements(entitlements)
pagerange = range(initial_page + 1, pages + 1) pagerange = range(initial_page + 1, pages + 1)
...@@ -266,23 +268,32 @@ class EcommerceApiDataLoader(AbstractDataLoader): ...@@ -266,23 +268,32 @@ class EcommerceApiDataLoader(AbstractDataLoader):
for page in pagerange: for page in pagerange:
executor.submit(self._load_data, page) executor.submit(self._load_data, page)
else: else:
for future in [executor.submit(self._make_request, page) for page in pagerange]: for future in [executor.submit(self._request_course_runs, page) for page in pagerange]:
response = future.result() response = future.result()
self._process_response(response) self._process_course_runs(response)
for future in [executor.submit(self._request_entitlments, page) for page in pagerange]:
response = future.result()
self._process_entitlements(response)
logger.info('Retrieved %d course seats from %s.', count, self.partner.ecommerce_api_url) logger.info('Retrieved %d course seats and %d course entitlements from %s.', course_runs['count'],
entitlements['count'], self.partner.ecommerce_api_url)
self.delete_orphans() self.delete_orphans()
def _load_data(self, page): # pragma: no cover def _load_data(self, page): # pragma: no cover
"""Make a request for the given page and process the response.""" """Make a request for the given page and process the response."""
response = self._make_request(page) course_runs = self._request_course_runs(page)
self._process_response(response) self._process_course_runs(course_runs)
entitlements = self._request_entitlments(page)
self._process_entitlements(entitlements)
def _make_request(self, page): def _request_course_runs(self, page):
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True) return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
def _process_response(self, response): def _request_entitlments(self, page):
return self.api_client.products().get(page=page, page_size=self.PAGE_SIZE, product_class='Course Entitlement')
def _process_course_runs(self, response):
results = response['results'] results = response['results']
logger.info('Retrieved %d course seats...', len(results)) logger.info('Retrieved %d course seats...', len(results))
...@@ -290,6 +301,16 @@ class EcommerceApiDataLoader(AbstractDataLoader): ...@@ -290,6 +301,16 @@ class EcommerceApiDataLoader(AbstractDataLoader):
body = self.clean_strings(body) body = self.clean_strings(body)
self.update_seats(body) self.update_seats(body)
def _process_entitlements(self, response):
results = response['results']
logger.info('Retrieved %d course entitlements...', len(results))
skus = []
for body in results:
body = self.clean_strings(body)
skus.append(self.update_entitlement(body))
CourseEntitlement.objects.exclude(sku__in=skus).delete()
def update_seats(self, body): def update_seats(self, body):
course_run_key = body['id'] course_run_key = body['id']
try: try:
...@@ -340,6 +361,44 @@ class EcommerceApiDataLoader(AbstractDataLoader): ...@@ -340,6 +361,44 @@ class EcommerceApiDataLoader(AbstractDataLoader):
course_run.seats.update_or_create(type=seat_type, credit_provider=credit_provider, currency=currency, course_run.seats.update_or_create(type=seat_type, credit_provider=credit_provider, currency=currency,
defaults=defaults) defaults=defaults)
def update_entitlement(self, body):
attributes = {attribute['name']: attribute['value'] for attribute in body['attribute_values']}
course_uuid = attributes.get('UUID')
try:
course = Course.objects.get(uuid=course_uuid)
except Course.DoesNotExist:
msg = 'Could not find course {uuid}'.format(uuid=course_uuid)
logger.warning(msg)
return None
stock_record = body['stockrecords'][0]
currency_code = stock_record['price_currency']
price = Decimal(stock_record['price_excl_tax'])
sku = stock_record['partner_sku']
try:
currency = Currency.objects.get(code=currency_code)
except Currency.DoesNotExist:
msg = 'Could not find currency {code}'.format(code=currency_code)
logger.warning(msg)
return None
mode_name = attributes.get('certificate_type')
try:
mode = SeatType.objects.get(name=mode_name)
except SeatType.DoesNotExist:
msg = 'Could not find course entitlement mode {mode}'.format(mode=mode_name)
logger.warning(msg)
return None
defaults = {
'price': price,
'currency': currency,
'sku': sku
}
course.entitlements.update_or_create(mode=mode, defaults=defaults)
return sku
def get_certificate_type(self, product): def get_certificate_type(self, product):
return next( return next(
(att['value'] for att in product['attribute_values'] if att['name'] == 'certificate_type'), (att['value'] for att in product['attribute_values'] if att['name'] == 'certificate_type'),
......
...@@ -14,7 +14,9 @@ from course_discovery.apps.course_metadata.data_loaders.api import ( ...@@ -14,7 +14,9 @@ from course_discovery.apps.course_metadata.data_loaders.api import (
) )
from course_discovery.apps.course_metadata.data_loaders.tests import JPEG, JSON, mock_data from course_discovery.apps.course_metadata.data_loaders.tests import JPEG, JSON, mock_data
from course_discovery.apps.course_metadata.data_loaders.tests.mixins import ApiClientTestMixin, DataLoaderTestMixin from course_discovery.apps.course_metadata.data_loaders.tests.mixins import ApiClientTestMixin, DataLoaderTestMixin
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program, ProgramType, Seat from course_discovery.apps.course_metadata.models import (
Course, CourseEntitlement, CourseRun, Organization, Program, ProgramType, Seat, SeatType
)
from course_discovery.apps.course_metadata.tests.factories import ( from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, ImageFactory, OrganizationFactory, SeatFactory, VideoFactory CourseFactory, CourseRunFactory, ImageFactory, OrganizationFactory, SeatFactory, VideoFactory
) )
...@@ -329,7 +331,7 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC ...@@ -329,7 +331,7 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC
def api_url(self): def api_url(self):
return self.partner.ecommerce_api_url return self.partner.ecommerce_api_url
def mock_api(self): def mock_courses_api(self):
# Create existing seats to be removed by ingest # Create existing seats to be removed by ingest
audit_run = CourseRunFactory(title_override='audit', key='audit/course/run') audit_run = CourseRunFactory(title_override='audit', key='audit/course/run')
verified_run = CourseRunFactory(title_override='verified', key='verified/course/run') verified_run = CourseRunFactory(title_override='verified', key='verified/course/run')
...@@ -351,6 +353,45 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC ...@@ -351,6 +353,45 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC
) )
return bodies return bodies
def mock_products_api(self, alt_course=None, alt_currency=None, alt_mode=None):
""" Return a new Course Entitlement to be added by ingest """
course = CourseFactory()
bodies = [
{
"structure": "child",
"product_class": "Course Entitlement",
"price": "10.00",
"expires": None,
"attribute_values": [
{
"name": "certificate_type",
"value": alt_mode if alt_mode else "verified",
},
{
"name": "UUID",
"value": alt_course if alt_course else str(course.uuid),
}
],
"is_available_to_buy": True,
"stockrecords": [
{
"price_currency": alt_currency if alt_currency else "USD",
"price_excl_tax": "10.00",
"partner_sku": "sku132",
}
]
}
]
url = '{url}products/'.format(url=self.api_url)
responses.add_callback(
responses.GET,
url,
callback=mock_api_callback(url, bodies),
content_type=JSON
)
return bodies
def assert_seats_loaded(self, body): def assert_seats_loaded(self, body):
""" Assert a Seat corresponding to the specified data body was properly loaded into the database. """ """ Assert a Seat corresponding to the specified data body was properly loaded into the database. """
course_run = CourseRun.objects.get(key=body['id']) course_run = CourseRun.objects.get(key=body['id'])
...@@ -393,12 +434,36 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC ...@@ -393,12 +434,36 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC
self.assertEqual(seat.upgrade_deadline, upgrade_deadline) self.assertEqual(seat.upgrade_deadline, upgrade_deadline)
self.assertEqual(seat.sku, sku) self.assertEqual(seat.sku, sku)
def assert_entitlements_loaded(self, body):
""" Assert a Course Entitlement was loaded into the database for each entry in the specified data body. """
self.assertEqual(CourseEntitlement.objects.count(), len(body))
for datum in body:
attributes = {attribute['name']: attribute['value'] for attribute in datum['attribute_values']}
course = Course.objects.get(uuid=attributes['UUID'])
stock_record = datum['stockrecords'][0]
price_currency = stock_record['price_currency']
price = Decimal(stock_record['price_excl_tax'])
sku = stock_record['partner_sku']
mode_name = attributes['certificate_type']
mode = SeatType.objects.get(name=mode_name)
entitlement = course.entitlements.get(mode=mode)
self.assertEqual(entitlement.course, course)
self.assertEqual(entitlement.mode, mode)
self.assertEqual(entitlement.price, price)
self.assertEqual(entitlement.currency.code, price_currency)
self.assertEqual(entitlement.sku, sku)
@responses.activate @responses.activate
def test_ingest(self): def test_ingest(self):
""" Verify the method ingests data from the E-Commerce API. """ """ Verify the method ingests data from the E-Commerce API. """
api_data = self.mock_api() courses_api_data = self.mock_courses_api()
loaded_course_run_data = api_data[:-1] loaded_course_run_data = courses_api_data[:-1]
loaded_seat_data = api_data[:-2] loaded_seat_data = courses_api_data[:-2]
products_api_data = self.mock_products_api()
self.assertEqual(CourseRun.objects.count(), len(loaded_course_run_data)) self.assertEqual(CourseRun.objects.count(), len(loaded_course_run_data))
...@@ -409,14 +474,38 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC ...@@ -409,14 +474,38 @@ class EcommerceApiDataLoaderTests(ApiClientTestMixin, DataLoaderTestMixin, TestC
self.loader.ingest() self.loader.ingest()
# Verify the API was called with the correct authorization header # Verify the API was called with the correct authorization header
self.assert_api_called(1) self.assert_api_called(2)
for datum in loaded_seat_data: for datum in loaded_seat_data:
self.assert_seats_loaded(datum) self.assert_seats_loaded(datum)
self.assert_entitlements_loaded(products_api_data)
# Verify multiple calls to ingest data do NOT result in data integrity errors. # Verify multiple calls to ingest data do NOT result in data integrity errors.
self.loader.ingest() self.loader.ingest()
@responses.activate
@ddt.data(
('a01354b1-c0de-4a6b-c5de-ab5c6d869e76', None, None),
(None, "NRC", None),
(None, None, "notamode")
)
@ddt.unpack
def test_ingest_fails(self, alt_course, alt_currency, alt_mode):
""" Verify the proper warnings are logged when data objects are not present. """
self.mock_courses_api()
self.mock_products_api(alt_course=alt_course, alt_currency=alt_currency, alt_mode=alt_mode)
with mock.patch(LOGGER_PATH) as mock_logger:
self.loader.ingest()
msg = 'Could not find '
if alt_course:
msg += 'course ' + alt_course
elif alt_currency:
msg += 'currency ' + alt_currency
else:
msg += 'course entitlement mode ' + alt_mode
mock_logger.warning.assert_called_with(msg)
@ddt.unpack @ddt.unpack
@ddt.data( @ddt.data(
({"attribute_values": []}, Seat.AUDIT), ({"attribute_values": []}, Seat.AUDIT),
......
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-11-08 16:14
from __future__ import unicode_literals
import django.db.models.deletion
import django_extensions.db.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0007_auto_20171004_1133'),
('course_metadata', '0067_auto_20171108_1432'),
]
operations = [
migrations.CreateModel(
name='CourseEntitlement',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created', django_extensions.db.fields.CreationDateTimeField(auto_now_add=True, verbose_name='created')),
('modified', django_extensions.db.fields.ModificationDateTimeField(auto_now=True, verbose_name='modified')),
('price', models.DecimalField(decimal_places=2, default=0.0, max_digits=10)),
('sku', models.CharField(blank=True, max_length=128, null=True)),
('course', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='entitlements', to='course_metadata.Course')),
('currency', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.Currency')),
('mode', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='course_metadata.SeatType')),
],
),
migrations.AlterUniqueTogether(
name='courseentitlement',
unique_together=set([('course', 'mode')]),
),
]
...@@ -766,6 +766,26 @@ class Seat(TimeStampedModel): ...@@ -766,6 +766,26 @@ class Seat(TimeStampedModel):
) )
class CourseEntitlement(TimeStampedModel):
""" Model storing product metadata for a Course. """
PRICE_FIELD_CONFIG = {
'decimal_places': 2,
'max_digits': 10,
'null': False,
'default': 0.00,
}
course = models.ForeignKey(Course, related_name='entitlements')
mode = models.ForeignKey(SeatType)
price = models.DecimalField(**PRICE_FIELD_CONFIG)
currency = models.ForeignKey(Currency)
sku = models.CharField(max_length=128, null=True, blank=True)
class Meta(object):
unique_together = (
('course', 'mode')
)
class Endorsement(TimeStampedModel): class Endorsement(TimeStampedModel):
endorser = models.ForeignKey(Person, blank=False, null=False) endorser = models.ForeignKey(Person, blank=False, null=False)
quote = models.TextField(blank=False, null=False) quote = models.TextField(blank=False, null=False)
......
...@@ -348,3 +348,14 @@ class PersonWorkFactory(factory.django.DjangoModelFactory): ...@@ -348,3 +348,14 @@ class PersonWorkFactory(factory.django.DjangoModelFactory):
model = PersonWork model = PersonWork
person = factory.SubFactory(PersonFactory) person = factory.SubFactory(PersonFactory)
class CourseEntitlementFactory(factory.DjangoModelFactory):
mode = factory.SubFactory(SeatTypeFactory)
price = FuzzyDecimal(0.0, 650.0)
currency = factory.Iterator(Currency.objects.all())
sku = FuzzyText(length=8)
course = factory.SubFactory(CourseFactory)
class Meta:
model = CourseEntitlement
...@@ -995,6 +995,23 @@ class ProgramTypeTests(TestCase): ...@@ -995,6 +995,23 @@ class ProgramTypeTests(TestCase):
self.assertEqual(str(program_type), program_type.name) self.assertEqual(str(program_type), program_type.name)
class CourseEntitlementTests(TestCase):
""" Tests of the CourseEntitlement model. """
def setUp(self):
super(CourseEntitlementTests, self).setUp()
self.course = factories.CourseFactory()
self.mode = factories.SeatTypeFactory()
def test_unique_constraint(self):
"""
Verify that a CourseEntitlement does not allow multiple skus or prices for the same course and mode.
"""
factories.CourseEntitlementFactory(course=self.course, mode=self.mode)
with self.assertRaises(IntegrityError):
factories.CourseEntitlementFactory(course=self.course, mode=self.mode)
class EndorsementTests(TestCase): class EndorsementTests(TestCase):
""" Tests of the Endorsement model. """ """ Tests of the Endorsement model. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment