Commit 3dd051a1 by Renzo Lucioni Committed by Clinton Blackburn

Roll back attempted program list optimizations

parent 54feb857
# pylint: disable=abstract-method # pylint: disable=abstract-method
import datetime
import json import json
from urllib.parse import urlencode from urllib.parse import urlencode
import pytz
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
...@@ -83,7 +85,7 @@ PREFETCH_FIELDS = { ...@@ -83,7 +85,7 @@ PREFETCH_FIELDS = {
} }
SELECT_RELATED_FIELDS = { SELECT_RELATED_FIELDS = {
'course': ['level_type', 'video', 'partner', ], 'course': ['level_type', 'video', ],
'course_run': ['course', 'language', 'video', ], 'course_run': ['course', 'language', 'video', ],
'program': ['type', 'video', 'partner', ], 'program': ['type', 'video', 'partner', ],
} }
...@@ -225,21 +227,13 @@ class SeatSerializer(serializers.ModelSerializer): ...@@ -225,21 +227,13 @@ class SeatSerializer(serializers.ModelSerializer):
fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours',) fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours',)
class MinimalOrganizationSerializer(serializers.ModelSerializer): class OrganizationSerializer(TaggitSerializer, serializers.ModelSerializer):
class Meta:
model = Organization
fields = ('uuid', 'key', 'name',)
class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
"""Serializer for the ``Organization`` model.""" """Serializer for the ``Organization`` model."""
tags = TagListSerializerField() tags = TagListSerializerField()
class Meta(MinimalOrganizationSerializer.Meta): class Meta(object):
model = Organization model = Organization
fields = MinimalOrganizationSerializer.Meta.fields + ( fields = ('key', 'name', 'description', 'homepage_url', 'tags', 'logo_image_url', 'marketing_url')
'description', 'homepage_url', 'tags', 'logo_image_url', 'marketing_url',
)
class CatalogSerializer(serializers.ModelSerializer): class CatalogSerializer(serializers.ModelSerializer):
...@@ -277,13 +271,7 @@ class NestedProgramSerializer(serializers.ModelSerializer): ...@@ -277,13 +271,7 @@ class NestedProgramSerializer(serializers.ModelSerializer):
read_only_fields = ('uuid', 'marketing_url',) read_only_fields = ('uuid', 'marketing_url',)
class MinimalCourseRunSerializer(TimestampModelSerializer): class CourseRunSerializer(TimestampModelSerializer):
class Meta:
model = CourseRun
fields = ('key', 'uuid', 'title',)
class CourseRunSerializer(MinimalCourseRunSerializer):
"""Serializer for the ``CourseRun`` model.""" """Serializer for the ``CourseRun`` model."""
course = serializers.SlugRelatedField(read_only=True, slug_field='key') course = serializers.SlugRelatedField(read_only=True, slug_field='key')
content_language = serializers.SlugRelatedField( content_language = serializers.SlugRelatedField(
...@@ -299,13 +287,13 @@ class CourseRunSerializer(MinimalCourseRunSerializer): ...@@ -299,13 +287,13 @@ class CourseRunSerializer(MinimalCourseRunSerializer):
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
level_type = serializers.SlugRelatedField(read_only=True, slug_field='name') level_type = serializers.SlugRelatedField(read_only=True, slug_field='name')
class Meta(MinimalCourseRunSerializer.Meta): class Meta(object):
model = CourseRun model = CourseRun
fields = MinimalCourseRunSerializer.Meta.fields + ( fields = (
'course', 'short_description', 'full_description', 'start', 'end', 'enrollment_start', 'enrollment_end', 'course', 'key', 'title', 'short_description', 'full_description', 'start', 'end',
'announcement', 'image', 'video', 'seats', 'content_language', 'transcript_languages', 'instructors', 'enrollment_start', 'enrollment_end', 'announcement', 'image', 'video', 'seats',
'staff', 'pacing_type', 'min_effort', 'max_effort', 'modified', 'marketing_url', 'level_type', 'content_language', 'transcript_languages', 'instructors', 'staff',
'availability', 'pacing_type', 'min_effort', 'max_effort', 'modified', 'marketing_url', 'level_type', 'availability',
) )
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
...@@ -332,16 +320,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer): ...@@ -332,16 +320,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer):
) )
class MinimalCourseSerializer(TimestampModelSerializer): class CourseSerializer(TimestampModelSerializer):
course_runs = MinimalCourseRunSerializer(many=True)
owners = MinimalOrganizationSerializer(many=True, source='authoring_organizations')
class Meta:
model = Course
fields = ('key', 'uuid', 'title', 'course_runs', 'owners',)
class CourseSerializer(MinimalCourseSerializer):
"""Serializer for the ``Course`` model.""" """Serializer for the ``Course`` model."""
level_type = serializers.SlugRelatedField(read_only=True, slug_field='name') level_type = serializers.SlugRelatedField(read_only=True, slug_field='name')
subjects = SubjectSerializer(many=True) subjects = SubjectSerializer(many=True)
...@@ -354,11 +333,12 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -354,11 +333,12 @@ class CourseSerializer(MinimalCourseSerializer):
course_runs = CourseRunSerializer(many=True) course_runs = CourseRunSerializer(many=True)
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
class Meta(MinimalCourseSerializer.Meta): class Meta(object):
model = Course model = Course
fields = MinimalCourseSerializer.Meta.fields + ( fields = (
'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites', 'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites',
'expected_learning_items', 'image', 'video', 'sponsors', 'modified', 'marketing_url', 'expected_learning_items', 'image', 'video', 'owners', 'sponsors', 'modified', 'course_runs',
'marketing_url',
) )
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
...@@ -390,55 +370,29 @@ class ContainedCoursesSerializer(serializers.Serializer): ...@@ -390,55 +370,29 @@ class ContainedCoursesSerializer(serializers.Serializer):
) )
class ProgramCourseSerializer(MinimalCourseSerializer): class ProgramCourseSerializer(CourseSerializer):
"""Serializer used to filter out excluded course runs in a course associated with the program""" """Serializer used to filter out excluded course runs in a course associated with the program"""
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
def get_course_runs(self, course): def get_course_runs(self, course):
program = self.context['program'] course_runs = self.context['course_runs']
course_runs = list(course.course_runs.all()) course_runs = [course_run for course_run in course_runs if course_run.course == course]
excluded_course_runs = list(program.excluded_course_runs.all())
course_runs = [course_run for course_run in course_runs if course_run not in excluded_course_runs]
if self.context.get('published_course_runs_only'): if self.context.get('published_course_runs_only'):
course_runs = [course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published] course_runs = [course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published]
return MinimalCourseRunSerializer( return CourseRunSerializer(
course_runs, course_runs,
many=True, many=True,
context={'request': self.context.get('request')} context={'request': self.context.get('request')}
).data ).data
class MinimalProgramSerializer(serializers.ModelSerializer): class ProgramSerializer(serializers.ModelSerializer):
authoring_organizations = MinimalOrganizationSerializer(many=True)
banner_image = StdImageSerializerField()
courses = serializers.SerializerMethodField() courses = serializers.SerializerMethodField()
authoring_organizations = OrganizationSerializer(many=True)
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all()) type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
banner_image = StdImageSerializerField()
def get_courses(self, program):
course_serializer = ProgramCourseSerializer(
program.courses.all(),
many=True,
context={
'request': self.context.get('request'),
'program': program,
'published_course_runs_only': self.context.get('published_course_runs_only'),
}
)
return course_serializer.data
class Meta:
model = Program
fields = (
'uuid', 'title', 'subtitle', 'type', 'status', 'marketing_slug', 'marketing_url', 'banner_image', 'courses',
'authoring_organizations', 'card_image_url',
)
read_only_fields = ('uuid', 'marketing_url', 'banner_image')
class ProgramSerializer(MinimalProgramSerializer):
video = VideoSerializer() video = VideoSerializer()
expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value') expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value')
faq = FAQSerializer(many=True) faq = FAQSerializer(many=True)
...@@ -457,14 +411,90 @@ class ProgramSerializer(MinimalProgramSerializer): ...@@ -457,14 +411,90 @@ class ProgramSerializer(MinimalProgramSerializer):
subjects = SubjectSerializer(many=True) subjects = SubjectSerializer(many=True)
staff = PersonSerializer(many=True) staff = PersonSerializer(many=True)
class Meta(MinimalProgramSerializer.Meta): def get_courses(self, program):
courses, course_runs = self.sort_courses(program)
course_serializer = ProgramCourseSerializer(
courses,
many=True,
context={
'request': self.context.get('request'),
'program': program,
'published_course_runs_only': self.context.get('published_course_runs_only'),
'course_runs': course_runs,
}
)
return course_serializer.data
def sort_courses(self, program):
"""
Sorting by enrollment start then by course start yields a list ordered by course start, with
ties broken by enrollment start. This works because Python sorting is stable: two objects with
equal keys appear in the same order in sorted output as they appear in the input.
Courses are only created if there's at least one course run belonging to that course, so
course_runs should never be empty. If it is, key functions in this method attempting to find the
min of an empty sequence will raise a ValueError.
"""
course_runs = program.course_runs.select_related(*SELECT_RELATED_FIELDS['course_run'])
course_runs = course_runs.prefetch_related(*PREFETCH_FIELDS['course_run'])
course_runs = list(course_runs)
def min_run_enrollment_start(course):
# Enrollment starts may be empty. When this is the case, we make the same assumption as
# the LMS: no enrollment_start is equivalent to (offset-aware) datetime.datetime.min.
min_datetime = datetime.datetime.min.replace(tzinfo=pytz.UTC)
# Course runs excluded from the program are excluded here, too.
#
# If this becomes a candidate for optimization in the future, be careful sorting null values
# in the database. PostgreSQL and MySQL sort null values as if they are higher than non-null
# values, while SQLite does the opposite.
#
# For more, refer to https://docs.djangoproject.com/en/1.10/ref/models/querysets/#latest.
_course_runs = [course_run for course_run in course_runs if course_run.course == course]
# Return early if we have no course runs since min() will fail.
if not _course_runs:
return min_datetime
run = min(_course_runs, key=lambda run: run.enrollment_start or min_datetime)
return run.enrollment_start or min_datetime
def min_run_start(course):
# Course starts may be empty. Since this means the course can't be started, missing course
# start date is equivalent to (offset-aware) datetime.datetime.max.
max_datetime = datetime.datetime.max.replace(tzinfo=pytz.UTC)
_course_runs = [course_run for course_run in course_runs if course_run.course == course]
# Return early if we have no course runs since min() will fail.
if not _course_runs:
return max_datetime
run = min(_course_runs, key=lambda run: run.start or max_datetime)
return run.start or max_datetime
courses = list(program.courses.all())
courses.sort(key=min_run_enrollment_start)
courses.sort(key=min_run_start)
return courses, course_runs
class Meta:
model = Program model = Program
fields = MinimalProgramSerializer.Meta.fields + ( fields = (
'overview', 'weeks_to_complete', 'min_hours_effort_per_week', 'max_hours_effort_per_week', 'video', 'uuid', 'title', 'subtitle', 'type', 'status', 'marketing_slug', 'marketing_url', 'courses',
'overview', 'weeks_to_complete', 'min_hours_effort_per_week', 'max_hours_effort_per_week',
'authoring_organizations', 'banner_image', 'banner_image_url', 'card_image_url', 'video',
'expected_learning_items', 'faq', 'credit_backing_organizations', 'corporate_endorsements', 'expected_learning_items', 'faq', 'credit_backing_organizations', 'corporate_endorsements',
'job_outlook_items', 'individual_endorsements', 'languages', 'transcript_languages', 'subjects', 'job_outlook_items', 'individual_endorsements', 'languages', 'transcript_languages', 'subjects',
'price_ranges', 'staff', 'credit_redemption_overview', 'price_ranges', 'staff', 'credit_redemption_overview'
) )
read_only_fields = ('uuid', 'marketing_url', 'banner_image')
class AffiliateWindowSerializer(serializers.ModelSerializer): class AffiliateWindowSerializer(serializers.ModelSerializer):
......
from django.test import TestCase from django.test import TestCase
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField
from course_discovery.apps.api.tests.test_serializers import make_request
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase): class ImageFieldTests(TestCase):
...@@ -16,27 +13,3 @@ class ImageFieldTests(TestCase): ...@@ -16,27 +13,3 @@ class ImageFieldTests(TestCase):
'width': None 'width': None
} }
self.assertEqual(ImageField().to_representation(value), expected) self.assertEqual(ImageField().to_representation(value), expected)
# pylint: disable=no-member
class StdImageSerializerFieldTests(TestCase):
def test_to_representation(self):
request = make_request()
# TODO Create test-only model to avoid unnecessary dependency on Program model.
program = ProgramFactory(banner_image=make_image_file('test.jpg'))
field = StdImageSerializerField()
field._context = {'request': request} # pylint: disable=protected-access
expected = {
size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
self.assertDictEqual(field.to_representation(program.banner_image), expected)
import unittest
from datetime import datetime from datetime import datetime
from urllib.parse import urlencode from urllib.parse import urlencode
...@@ -7,16 +8,14 @@ from haystack.query import SearchQuerySet ...@@ -7,16 +8,14 @@ from haystack.query import SearchQuerySet
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseRunSerializer, ContainedCoursesSerializer, ImageSerializer, CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer, ImageSerializer,
SubjectSerializer, PrerequisiteSerializer, VideoSerializer, OrganizationSerializer, SeatSerializer, SubjectSerializer, PrerequisiteSerializer, VideoSerializer, OrganizationSerializer, SeatSerializer,
PersonSerializer, AffiliateWindowSerializer, ContainedCourseRunsSerializer, CourseRunSearchSerializer, PersonSerializer, AffiliateWindowSerializer, ContainedCourseRunsSerializer, CourseRunSearchSerializer,
ProgramSerializer, ProgramSearchSerializer, ProgramCourseSerializer, NestedProgramSerializer, ProgramSerializer, ProgramSearchSerializer, ProgramCourseSerializer, NestedProgramSerializer,
CourseRunWithProgramsSerializer, CourseWithProgramsSerializer, CorporateEndorsementSerializer, CourseRunWithProgramsSerializer, CourseWithProgramsSerializer, CorporateEndorsementSerializer,
FAQSerializer, EndorsementSerializer, PositionSerializer, FlattenedCourseRunWithCourseSerializer, FAQSerializer, EndorsementSerializer, PositionSerializer, FlattenedCourseRunWithCourseSerializer
MinimalCourseSerializer, MinimalOrganizationSerializer, MinimalCourseRunSerializer, MinimalProgramSerializer,
CourseSerializer
) )
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
...@@ -32,7 +31,7 @@ from course_discovery.apps.course_metadata.tests.factories import ( ...@@ -32,7 +31,7 @@ from course_discovery.apps.course_metadata.tests.factories import (
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint:disable=no-member, test-inherits-tests # pylint:disable=no-member
def json_date_format(datetime_obj): def json_date_format(datetime_obj):
return datetime.strftime(datetime_obj, "%Y-%m-%dT%H:%M:%S.%fZ") return datetime.strftime(datetime_obj, "%Y-%m-%dT%H:%M:%S.%fZ")
...@@ -94,36 +93,19 @@ class CatalogSerializerTests(TestCase): ...@@ -94,36 +93,19 @@ class CatalogSerializerTests(TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(TestCase): class CourseSerializerTests(TestCase):
serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request):
context = {'request': request}
return {
'key': course.key,
'uuid': str(course.uuid),
'title': course.title,
'course_runs': MinimalCourseRunSerializer(course.course_runs, many=True, context=context).data,
'owners': MinimalOrganizationSerializer(course.authoring_organizations, many=True, context=context).data,
}
def test_data(self): def test_data(self):
request = make_request() course = CourseFactory()
organizations = OrganizationFactory() video = course.video
course = CourseFactory(authoring_organizations=[organizations])
CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request)
self.assertDictEqual(serializer.data, expected)
request = make_request()
class CourseSerializerTests(MinimalCourseSerializerTests): CourseRunFactory.create_batch(3, course=course)
serializer_class = CourseSerializer serializer = CourseWithProgramsSerializer(course, context={'request': request})
def get_expected_data(self, course, request): expected = {
expected = super().get_expected_data(course, request) 'key': course.key,
expected.update({ 'title': course.title,
'short_description': course.short_description, 'short_description': course.short_description,
'full_description': course.full_description, 'full_description': course.full_description,
'level_type': course.level_type.name, 'level_type': course.level_type.name,
...@@ -131,9 +113,11 @@ class CourseSerializerTests(MinimalCourseSerializerTests): ...@@ -131,9 +113,11 @@ class CourseSerializerTests(MinimalCourseSerializerTests):
'prerequisites': [], 'prerequisites': [],
'expected_learning_items': [], 'expected_learning_items': [],
'image': ImageField().to_representation(course.card_image_url), 'image': ImageField().to_representation(course.card_image_url),
'video': VideoSerializer(course.video).data, 'video': VideoSerializer(video).data,
'owners': OrganizationSerializer(course.authoring_organizations, many=True).data,
'sponsors': OrganizationSerializer(course.sponsoring_organizations, many=True).data, 'sponsors': OrganizationSerializer(course.sponsoring_organizations, many=True).data,
'modified': json_date_format(course.modified), # pylint: disable=no-member 'modified': json_date_format(course.modified), # pylint: disable=no-member
'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data,
'marketing_url': '{url}?{params}'.format( 'marketing_url': '{url}?{params}'.format(
url=course.marketing_url, url=course.marketing_url,
params=urlencode({ params=urlencode({
...@@ -141,47 +125,22 @@ class CourseSerializerTests(MinimalCourseSerializerTests): ...@@ -141,47 +125,22 @@ class CourseSerializerTests(MinimalCourseSerializerTests):
'utm_medium': request.user.referral_tracking_id, 'utm_medium': request.user.referral_tracking_id,
}) })
), ),
'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data,
'owners': OrganizationSerializer(course.authoring_organizations, many=True).data,
})
return expected
class CourseWithProgramsSerializerTests(CourseSerializerTests): # pylint: disable=test-inherits-tests
serializer_class = CourseWithProgramsSerializer
def get_expected_data(self, course, request):
expected = super().get_expected_data(course, request)
expected.update({
'programs': NestedProgramSerializer(course.programs, many=True, context={'request': request}).data, 'programs': NestedProgramSerializer(course.programs, many=True, context={'request': request}).data,
}) }
return expected
class MinimalCourseRunSerializerTests(TestCase): self.assertDictEqual(serializer.data, expected)
serializer_class = MinimalCourseRunSerializer
def get_expected_data(self, course_run, request): # pylint: disable=unused-argument
return {
'key': course_run.key,
'uuid': str(course_run.uuid),
'title': course_run.title,
}
class CourseRunSerializerTests(TestCase):
def test_data(self): def test_data(self):
request = make_request() request = make_request()
course_run = CourseRunFactory() course_run = CourseRunFactory()
serializer = self.serializer_class(course_run, context={'request': request}) course = course_run.course
expected = self.get_expected_data(course_run, request) video = course_run.video
self.assertDictEqual(serializer.data, expected) serializer = CourseRunSerializer(course_run, context={'request': request})
ProgramFactory(courses=[course])
class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disable=test-inherits-tests
serializer_class = CourseRunSerializer
def get_expected_data(self, course_run, request): expected = {
expected = super().get_expected_data(course_run, request)
expected.update({
'course': course_run.course.key, 'course': course_run.course.key,
'key': course_run.key, 'key': course_run.key,
'title': course_run.title, # pylint: disable=no-member 'title': course_run.title, # pylint: disable=no-member
...@@ -193,7 +152,7 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa ...@@ -193,7 +152,7 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa
'enrollment_end': json_date_format(course_run.enrollment_end), 'enrollment_end': json_date_format(course_run.enrollment_end),
'announcement': json_date_format(course_run.announcement), 'announcement': json_date_format(course_run.announcement),
'image': ImageField().to_representation(course_run.card_image_url), 'image': ImageField().to_representation(course_run.card_image_url),
'video': VideoSerializer(course_run.video).data, 'video': VideoSerializer(video).data,
'pacing_type': course_run.pacing_type, 'pacing_type': course_run.pacing_type,
'content_language': course_run.language.code, 'content_language': course_run.language.code,
'transcript_languages': [], 'transcript_languages': [],
...@@ -212,8 +171,9 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa ...@@ -212,8 +171,9 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa
), ),
'level_type': course_run.level_type.name, 'level_type': course_run.level_type.name,
'availability': course_run.availability, 'availability': course_run.availability,
}) }
return expected
self.assertDictEqual(serializer.data, expected)
class CourseRunWithProgramsSerializerTests(TestCase): class CourseRunWithProgramsSerializerTests(TestCase):
...@@ -232,6 +192,7 @@ class CourseRunWithProgramsSerializerTests(TestCase): ...@@ -232,6 +192,7 @@ class CourseRunWithProgramsSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
@unittest.skip('This test is disabled until we can determine why assertDictEqual() fails for two equivalent inputs.')
class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
def serialize_seats(self, course_run): def serialize_seats(self, course_run):
seats = { seats = {
...@@ -281,7 +242,7 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover ...@@ -281,7 +242,7 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
def get_expected_data(self, request, course_run): def get_expected_data(self, request, course_run):
course = course_run.course course = course_run.course
serializer_context = {'request': request} serializer_context = {'request': request}
expected = dict(CourseRunSerializer(course_run, context=serializer_context).data) expected = CourseRunSerializer(course_run, context=serializer_context).data
expected.update({ expected.update({
'subjects': self.serialize_items(course.subjects.all(), 'name'), 'subjects': self.serialize_items(course.subjects.all(), 'name'),
'seats': self.serialize_seats(course_run), 'seats': self.serialize_seats(course_run),
...@@ -322,6 +283,7 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover ...@@ -322,6 +283,7 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
@ddt.ddt
class ProgramCourseSerializerTests(TestCase): class ProgramCourseSerializerTests(TestCase):
def setUp(self): def setUp(self):
super(ProgramCourseSerializerTests, self).setUp() super(ProgramCourseSerializerTests, self).setUp()
...@@ -336,10 +298,10 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -336,10 +298,10 @@ class ProgramCourseSerializerTests(TestCase):
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
self.course_list, self.course_list,
many=True, many=True,
context={'request': self.request, 'program': self.program} context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
) )
expected = MinimalCourseSerializer(self.course_list, many=True, context={'request': self.request}).data expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertSequenceEqual(serializer.data, expected) self.assertSequenceEqual(serializer.data, expected)
...@@ -349,10 +311,10 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -349,10 +311,10 @@ class ProgramCourseSerializerTests(TestCase):
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
self.course_list, self.course_list,
many=True, many=True,
context={'request': self.request, 'program': self.program} context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
) )
expected = MinimalCourseSerializer(self.course_list, many=True, context={'request': self.request}).data expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertSequenceEqual(serializer.data, expected) self.assertSequenceEqual(serializer.data, expected)
...@@ -366,138 +328,326 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -366,138 +328,326 @@ class ProgramCourseSerializerTests(TestCase):
excluded_runs.append(course_runs[0]) excluded_runs.append(course_runs[0])
program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs) program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs)
serializer_context = {'request': self.request, 'program': program} serializer_context = {'request': self.request, 'program': program, 'course_runs': program.course_runs}
serializer = ProgramCourseSerializer(course, context=serializer_context) serializer = ProgramCourseSerializer(course, context=serializer_context)
expected = MinimalCourseSerializer(course, context=serializer_context).data expected = CourseSerializer(course, context=serializer_context).data
expected['course_runs'] = MinimalCourseRunSerializer([course_runs[1]], many=True, expected['course_runs'] = CourseRunSerializer([course_runs[1]], many=True,
context={'request': self.request}).data context={'request': self.request}).data
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_with_published_course_runs_only_context(self): @ddt.data(
""" Verify setting the published_course_runs_only context value excludes unpublished course runs. """ [CourseRunStatus.Unpublished, 1],
# Create a program and course. The course should have both published and un-published course runs. [CourseRunStatus.Unpublished, 0],
course = CourseFactory() [CourseRunStatus.Published, 1],
courses = [course] [CourseRunStatus.Published, 0]
program = ProgramFactory(courses=courses) )
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=course) @ddt.unpack
CourseRunFactory(status=CourseRunStatus.Published, course=course) def test_with_published_only_querystring(self, course_run_status, published_course_runs_only):
"""
# We do NOT expect the results to included the unpublished data Test the serializer's ability to filter out course_runs based on
expected = MinimalCourseSerializer(courses, many=True, context={'request': self.request}).data "published_course_runs_only" query string
expected[0]['course_runs'] = [course_run for course_run in expected[0]['course_runs'] if """
course_run['uuid'] != str(unpublished_course_run.uuid)] expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertEqual(len(expected[0]['course_runs']), 1)
for course in self.course_list:
CourseRunFactory.create_batch(2, status=course_run_status, course=course)
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
courses, self.course_list,
many=True, many=True,
context={ context={
'request': self.request, 'request': self.request,
'program': program, 'program': self.program,
'published_course_runs_only': True, 'published_course_runs_only': published_course_runs_only,
'course_runs': self.program.course_runs
} }
) )
validate_data = serializer.data
self.assertSequenceEqual(serializer.data, expected) if not published_course_runs_only or course_run_status != CourseRunStatus.Unpublished:
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertSequenceEqual(validate_data, expected)
class MinimalProgramSerializerTests(TestCase):
serializer_class = MinimalProgramSerializer
def create_program(self):
organizations = OrganizationFactory.create_batch(2)
person = PersonFactory()
courses = CourseFactory.create_batch(3)
for course in courses:
CourseRunFactory.create_batch(2, course=course, staff=[person])
class ProgramSerializerTests(TestCase):
def test_data(self):
request = make_request()
org_list = OrganizationFactory.create_batch(1)
course_list = CourseFactory.create_batch(3)
for course in course_list:
CourseRunFactory.create_batch(
3,
course=course,
enrollment_start=datetime(2014, 1, 1),
start=datetime(2014, 1, 1)
)
corporate_endorsements = CorporateEndorsementFactory.create_batch(1)
individual_endorsements = EndorsementFactory.create_batch(1)
staff = PersonFactory.create_batch(1)
job_outlook_items = JobOutlookItemFactory.create_batch(1)
expected_learning_items = ExpectedLearningItemFactory.create_batch(1)
program = ProgramFactory( program = ProgramFactory(
courses=courses, authoring_organizations=org_list,
authoring_organizations=organizations, courses=course_list,
credit_backing_organizations=organizations, credit_backing_organizations=org_list,
corporate_endorsements=CorporateEndorsementFactory.create_batch(1), corporate_endorsements=corporate_endorsements,
individual_endorsements=EndorsementFactory.create_batch(1), individual_endorsements=individual_endorsements,
expected_learning_items=ExpectedLearningItemFactory.create_batch(1), expected_learning_items=expected_learning_items,
job_outlook_items=JobOutlookItemFactory.create_batch(1), staff=staff,
banner_image=make_image_file('test_banner.jpg'), job_outlook_items=job_outlook_items,
video=VideoFactory()
) )
return program program.banner_image = make_image_file('test_banner.jpg')
program.save()
def get_expected_data(self, program, request): serializer = ProgramSerializer(program, context={'request': request})
image_field = StdImageSerializerField() expected_banner_image_urls = {
image_field._context = {'request': request} # pylint: disable=protected-access size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
return { expected = {
'uuid': str(program.uuid), 'uuid': str(program.uuid),
'title': program.title, 'title': program.title,
'subtitle': program.subtitle, 'subtitle': program.subtitle,
'type': program.type.name, 'type': program.type.name,
'status': program.status,
'marketing_slug': program.marketing_slug, 'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url, 'marketing_url': program.marketing_url,
'banner_image': image_field.to_representation(program.banner_image),
'courses': ProgramCourseSerializer(program.courses, many=True,
context={'request': request, 'program': program}).data,
'authoring_organizations': MinimalOrganizationSerializer(program.authoring_organizations, many=True).data,
'card_image_url': program.card_image_url, 'card_image_url': program.card_image_url,
'banner_image_url': program.banner_image_url,
'video': None,
'banner_image': expected_banner_image_urls,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'credit_redemption_overview': program.credit_redemption_overview,
'courses': ProgramCourseSerializer(
program.courses,
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer(
program.credit_backing_organizations,
many=True
).data,
'expected_learning_items': [item.value for item in program.expected_learning_items.all()],
'faq': FAQSerializer(program.faq, many=True).data,
'individual_endorsements': EndorsementSerializer(program.individual_endorsements, many=True).data,
'staff': PersonSerializer(program.staff, many=True).data,
'job_outlook_items': [item.value for item in program.job_outlook_items.all()],
'languages': [serialize_language_to_code(l) for l in program.languages],
'weeks_to_complete': program.weeks_to_complete,
'max_hours_effort_per_week': None,
'min_hours_effort_per_week': None,
'overview': None,
'price_ranges': [],
'status': program.status,
'subjects': [],
'transcript_languages': [],
} }
def test_data(self):
request = make_request()
program = self.create_program()
serializer = self.serializer_class(program, context={'request': request})
expected = self.get_expected_data(program, request)
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_with_exclusions(self):
"""
Verify we can specify program excluded_course_runs and the serializers will
render the course_runs with exclusions
"""
request = make_request()
org_list = OrganizationFactory.create_batch(1)
course_list = CourseFactory.create_batch(4)
excluded_runs = []
for course in course_list:
course_runs = CourseRunFactory.create_batch(
3,
course=course,
enrollment_start=datetime(2014, 1, 1),
start=datetime(2014, 1, 1)
)
excluded_runs.append(course_runs[0])
class ProgramSerializerTests(MinimalProgramSerializerTests): # pylint: disable=test-inherits-tests program = ProgramFactory(
serializer_class = ProgramSerializer authoring_organizations=org_list,
courses=course_list,
excluded_course_runs=excluded_runs
)
serializer = ProgramSerializer(program, context={'request': request})
def get_expected_data(self, program, request): expected = {
expected = super().get_expected_data(program, request) 'uuid': str(program.uuid),
expected.update({ 'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'marketing_slug': program.marketing_slug, 'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url, 'marketing_url': program.marketing_url,
'video': VideoSerializer(program.video).data, 'card_image_url': program.card_image_url,
'banner_image': {},
'banner_image_url': program.banner_image_url,
'video': None,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'credit_redemption_overview': program.credit_redemption_overview, 'credit_redemption_overview': program.credit_redemption_overview,
'courses': ProgramCourseSerializer(
program.courses,
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data, 'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer( 'credit_backing_organizations': OrganizationSerializer(
program.credit_backing_organizations, program.credit_backing_organizations,
many=True many=True
).data, ).data,
'expected_learning_items': [item.value for item in program.expected_learning_items.all()], 'expected_learning_items': [],
'faq': FAQSerializer(program.faq, many=True).data, 'faq': FAQSerializer(program.faq, many=True).data,
'individual_endorsements': EndorsementSerializer(program.individual_endorsements, many=True).data, 'individual_endorsements': EndorsementSerializer(program.individual_endorsements, many=True).data,
'staff': PersonSerializer(program.staff, many=True).data, 'staff': PersonSerializer(program.staff, many=True).data,
'job_outlook_items': [item.value for item in program.job_outlook_items.all()], 'job_outlook_items': [],
'languages': [serialize_language_to_code(l) for l in program.languages], 'languages': [serialize_language_to_code(l) for l in program.languages],
'weeks_to_complete': program.weeks_to_complete, 'weeks_to_complete': program.weeks_to_complete,
'max_hours_effort_per_week': program.max_hours_effort_per_week, 'max_hours_effort_per_week': None,
'min_hours_effort_per_week': program.min_hours_effort_per_week, 'min_hours_effort_per_week': None,
'overview': program.overview, 'overview': None,
'price_ranges': [], 'price_ranges': [],
'subjects': SubjectSerializer(program.subjects, many=True).data, 'status': program.status,
'transcript_languages': [serialize_language_to_code(l) for l in program.transcript_languages], 'subjects': [],
}) 'transcript_languages': [],
return expected }
self.assertDictEqual(serializer.data, expected)
def test_data_with_exclusions(self): def test_course_ordering(self):
""" """
Verify we can specify program excluded_course_runs and the serializers will Verify that courses in a program are ordered by ascending run start date,
render the course_runs with exclusions with ties broken by earliest run enrollment start date.
""" """
request = make_request() request = make_request()
program = self.create_program() course_list = CourseFactory.create_batch(3)
# Create a course run with arbitrary start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
# Create a second run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
excluded_course_run = program.courses.all()[0].course_runs.all()[0] # Create a third run with later start and enrollment_start.
program.excluded_course_runs.add(excluded_course_run) CourseRunFactory(
course=course_list[0],
enrollment_start=datetime(2014, 2, 1),
start=datetime(2014, 3, 1),
)
expected = self.get_expected_data(program, request) program = ProgramFactory(courses=course_list)
serializer = ProgramSerializer(program, context={'request': request}) serializer = ProgramSerializer(program, context={'request': request})
self.assertDictEqual(serializer.data, expected)
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
def test_course_ordering_with_exclusions(self):
"""
Verify that excluded course runs aren't used when ordering courses.
"""
request = make_request()
course_list = CourseFactory.create_batch(3)
# Create a course run with arbitrary start and empty enrollment_start.
# This run will be excluded from the program. If it wasn't excluded,
# the expected course ordering, by index, would be: 0, 2, 1.
excluded_run = CourseRunFactory(
course=course_list[0],
enrollment_start=None,
start=datetime(2014, 1, 1),
)
# Create a run with later start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
# Create a run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
# Create a run with later start and enrollment_start.
CourseRunFactory(
course=course_list[0],
enrollment_start=datetime(2014, 2, 1),
start=datetime(2014, 3, 1),
)
program = ProgramFactory(courses=course_list, excluded_course_runs=[excluded_run])
serializer = ProgramSerializer(program, context={'request': request})
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
def test_course_ordering_with_no_start(self):
"""
Verify that a courses run with missing start date appears last when ordering courses.
"""
request = make_request()
course_list = CourseFactory.create_batch(3)
# Create a course run with arbitrary start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
# Create a second run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
# Create a third run with empty start and enrollment_start.
CourseRunFactory(
course=course_list[0],
enrollment_start=None,
start=None,
)
program = ProgramFactory(courses=course_list)
serializer = ProgramSerializer(program, context={'request': request})
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
class ContainedCourseRunsSerializerTests(TestCase): class ContainedCourseRunsSerializerTests(TestCase):
...@@ -622,45 +772,24 @@ class VideoSerializerTests(TestCase): ...@@ -622,45 +772,24 @@ class VideoSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
class MinimalOrganizationSerializerTests(TestCase): class OrganizationSerializerTests(TestCase):
serializer_class = MinimalOrganizationSerializer def test_data(self):
organization = OrganizationFactory()
def create_organization(self): TAG = 'test'
return OrganizationFactory() organization.tags.add(TAG)
serializer = OrganizationSerializer(organization)
def get_expected_data(self, organization): expected = {
return {
'uuid': str(organization.uuid),
'key': organization.key, 'key': organization.key,
'name': organization.name, 'name': organization.name,
}
def test_data(self):
organization = self.create_organization()
serializer = self.serializer_class(organization)
expected = self.get_expected_data(organization)
self.assertDictEqual(serializer.data, expected)
class OrganizationSerializerTests(MinimalOrganizationSerializerTests):
TAG = 'test-tag'
serializer_class = OrganizationSerializer
def create_organization(self):
organization = super().create_organization()
organization.tags.add(self.TAG)
return organization
def get_expected_data(self, organization):
expected = super().get_expected_data(organization)
expected.update({
'description': organization.description, 'description': organization.description,
'homepage_url': organization.homepage_url, 'homepage_url': organization.homepage_url,
'logo_image_url': organization.logo_image_url, 'logo_image_url': organization.logo_image_url,
'tags': [self.TAG], 'tags': [TAG],
'marketing_url': organization.marketing_url, 'marketing_url': organization.marketing_url,
}) }
return expected
self.assertDictEqual(serializer.data, expected)
class SeatSerializerTests(TestCase): class SeatSerializerTests(TestCase):
......
...@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2') CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2')
CourseRunFactory(enrollment_end=enrollment_end, course=self.course) CourseRunFactory(enrollment_end=enrollment_end, course=self.course)
with self.assertNumQueries(40): with self.assertNumQueries(41):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True))
......
...@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(18): with self.assertNumQueries(19):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(24): with self.assertNumQueries(25):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(35): with self.assertNumQueries(38):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -40,17 +40,15 @@ class ProgramViewSetTests(APITestCase): ...@@ -40,17 +40,15 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = ProgramFactory()
with self.assertNumQueries(34): self.assert_retrieve_success(program)
self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory() course = CourseFactory()
program = ProgramFactory(courses=[course]) program = ProgramFactory(courses=[course])
with self.assertNumQueries(49): self.assert_retrieve_success(program)
self.assert_retrieve_success(program)
def assert_list_results(self, url, expected, expected_query_count): def assert_list_results(self, url, expected):
""" """
Asserts the results serialized/returned at the URL matches those that are expected. Asserts the results serialized/returned at the URL matches those that are expected.
Args: Args:
...@@ -64,9 +62,7 @@ class ProgramViewSetTests(APITestCase): ...@@ -64,9 +62,7 @@ class ProgramViewSetTests(APITestCase):
Returns: Returns:
None None
""" """
with self.assertNumQueries(expected_query_count): response = self.client.get(url)
response = self.client.get(url)
self.assertEqual( self.assertEqual(
response.data['results'], response.data['results'],
ProgramSerializer(expected, many=True, context={'request': self.request}).data ProgramSerializer(expected, many=True, context={'request': self.request}).data
...@@ -76,17 +72,17 @@ class ProgramViewSetTests(APITestCase): ...@@ -76,17 +72,17 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = ProgramFactory.create_batch(3)
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 40) self.assert_list_results(self.list_path, expected)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name) program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 18) self.assert_list_results(url, [program])
url = self.list_path + '?type=bar' url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 4) self.assert_list_results(url, [])
def test_filter_by_uuids(self): def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """ """ Verify that the endpoint filters programs to those matching the provided UUIDs. """
...@@ -98,14 +94,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -98,14 +94,14 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
self.assert_list_results(url, expected, 29) self.assert_list_results(url, expected)
@ddt.data( @ddt.data(
(ProgramStatus.Unpublished, False, 4), (ProgramStatus.Unpublished, False),
(ProgramStatus.Active, True, 40), (ProgramStatus.Active, True),
) )
@ddt.unpack @ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count): def test_filter_by_marketable(self, status, is_marketable):
""" Verify the endpoint filters programs to those that are marketable. """ """ Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1' url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='') ProgramFactory(marketing_slug='')
...@@ -114,4 +110,4 @@ class ProgramViewSetTests(APITestCase): ...@@ -114,4 +110,4 @@ class ProgramViewSetTests(APITestCase):
expected = programs if is_marketable else [] expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected) self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected, expected_query_count) self.assert_list_results(url, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment