Commit 35dc00f7 by Renzo Lucioni Committed by GitHub

Merge pull request #347 from edx/renzo/query-optimization

Query optimization
parents 54feb857 bc3d7d01
# pylint: disable=abstract-method # pylint: disable=abstract-method
import datetime
import json import json
from urllib.parse import urlencode from urllib.parse import urlencode
import pytz
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models.query import Prefetch
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
from rest_framework import serializers from rest_framework import serializers
...@@ -67,25 +70,39 @@ PROGRAM_FACET_FIELDS = BASE_PROGRAM_FIELDS + ('organizations',) ...@@ -67,25 +70,39 @@ PROGRAM_FACET_FIELDS = BASE_PROGRAM_FIELDS + ('organizations',)
PREFETCH_FIELDS = { PREFETCH_FIELDS = {
'course_run': [ 'course_run': [
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type', 'course__level_type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff', 'course__partner',
'staff__position', 'staff__position__organization', 'language', 'course__programs',
'course__programs__partner',
'course__programs__type',
'language',
'seats',
'seats__currency',
'staff',
'staff__position',
'staff__position__organization',
'transcript_languages',
], ],
'course': [ 'course': [
'level_type', 'video', 'programs', 'course_runs', 'subjects', 'prerequisites', 'expected_learning_items', 'authoring_organizations',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner', 'authoring_organizations__partner',
'sponsoring_organizations', 'sponsoring_organizations__tags', 'sponsoring_organizations__partner', 'authoring_organizations__tags',
], 'course_runs',
'program': [ 'expected_learning_items',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner', 'level_type',
'excluded_course_runs', 'courses', 'courses__authoring_organizations', 'courses__course_runs', 'prerequisites',
'programs',
'sponsoring_organizations',
'sponsoring_organizations__partner',
'sponsoring_organizations__tags',
'subjects',
'video',
], ],
} }
SELECT_RELATED_FIELDS = { SELECT_RELATED_FIELDS = {
'course': ['level_type', 'video', 'partner', ], 'course': ['level_type', 'partner', 'video'],
'course_run': ['course', 'language', 'video', ], 'course_run': ['course', 'language', 'video'],
'program': ['type', 'video', 'partner', ],
} }
...@@ -182,6 +199,10 @@ class PersonSerializer(serializers.ModelSerializer): ...@@ -182,6 +199,10 @@ class PersonSerializer(serializers.ModelSerializer):
"""Serializer for the ``Person`` model.""" """Serializer for the ``Person`` model."""
position = PositionSerializer() position = PositionSerializer()
@classmethod
def prefetch_queryset(cls):
return Person.objects.all().select_related('position__organization')
class Meta(object): class Meta(object):
model = Person model = Person
fields = ('uuid', 'given_name', 'family_name', 'bio', 'profile_image_url', 'slug', 'position') fields = ('uuid', 'given_name', 'family_name', 'bio', 'profile_image_url', 'slug', 'position')
...@@ -191,6 +212,10 @@ class EndorsementSerializer(serializers.ModelSerializer): ...@@ -191,6 +212,10 @@ class EndorsementSerializer(serializers.ModelSerializer):
"""Serializer for the ``Endorsement`` model.""" """Serializer for the ``Endorsement`` model."""
endorser = PersonSerializer() endorser = PersonSerializer()
@classmethod
def prefetch_queryset(cls):
return Endorsement.objects.all().select_related('endorser')
class Meta(object): class Meta(object):
model = Endorsement model = Endorsement
fields = ('endorser', 'quote',) fields = ('endorser', 'quote',)
...@@ -201,6 +226,12 @@ class CorporateEndorsementSerializer(serializers.ModelSerializer): ...@@ -201,6 +226,12 @@ class CorporateEndorsementSerializer(serializers.ModelSerializer):
image = ImageSerializer() image = ImageSerializer()
individual_endorsements = EndorsementSerializer(many=True) individual_endorsements = EndorsementSerializer(many=True)
@classmethod
def prefetch_queryset(cls):
return CorporateEndorsement.objects.all().select_related('image').prefetch_related(
Prefetch('endorser', queryset=EndorsementSerializer.prefetch_queryset()),
)
class Meta(object): class Meta(object):
model = CorporateEndorsement model = CorporateEndorsement
fields = ('corporation_name', 'statement', 'image', 'individual_endorsements',) fields = ('corporation_name', 'statement', 'image', 'individual_endorsements',)
...@@ -220,26 +251,26 @@ class SeatSerializer(serializers.ModelSerializer): ...@@ -220,26 +251,26 @@ class SeatSerializer(serializers.ModelSerializer):
credit_provider = serializers.CharField() credit_provider = serializers.CharField()
credit_hours = serializers.IntegerField() credit_hours = serializers.IntegerField()
@classmethod
def prefetch_queryset(cls):
return Seat.objects.all().select_related('currency')
class Meta(object): class Meta(object):
model = Seat model = Seat
fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours',) fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours',)
class MinimalOrganizationSerializer(serializers.ModelSerializer): class OrganizationSerializer(TaggitSerializer, serializers.ModelSerializer):
class Meta:
model = Organization
fields = ('uuid', 'key', 'name',)
class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
"""Serializer for the ``Organization`` model.""" """Serializer for the ``Organization`` model."""
tags = TagListSerializerField() tags = TagListSerializerField()
class Meta(MinimalOrganizationSerializer.Meta): @classmethod
def prefetch_queryset(cls):
return Organization.objects.all().select_related('partner').prefetch_related('tags')
class Meta(object):
model = Organization model = Organization
fields = MinimalOrganizationSerializer.Meta.fields + ( fields = ('key', 'name', 'description', 'homepage_url', 'tags', 'logo_image_url', 'marketing_url')
'description', 'homepage_url', 'tags', 'logo_image_url', 'marketing_url',
)
class CatalogSerializer(serializers.ModelSerializer): class CatalogSerializer(serializers.ModelSerializer):
...@@ -277,13 +308,7 @@ class NestedProgramSerializer(serializers.ModelSerializer): ...@@ -277,13 +308,7 @@ class NestedProgramSerializer(serializers.ModelSerializer):
read_only_fields = ('uuid', 'marketing_url',) read_only_fields = ('uuid', 'marketing_url',)
class MinimalCourseRunSerializer(TimestampModelSerializer): class CourseRunSerializer(TimestampModelSerializer):
class Meta:
model = CourseRun
fields = ('key', 'uuid', 'title',)
class CourseRunSerializer(MinimalCourseRunSerializer):
"""Serializer for the ``CourseRun`` model.""" """Serializer for the ``CourseRun`` model."""
course = serializers.SlugRelatedField(read_only=True, slug_field='key') course = serializers.SlugRelatedField(read_only=True, slug_field='key')
content_language = serializers.SlugRelatedField( content_language = serializers.SlugRelatedField(
...@@ -299,13 +324,21 @@ class CourseRunSerializer(MinimalCourseRunSerializer): ...@@ -299,13 +324,21 @@ class CourseRunSerializer(MinimalCourseRunSerializer):
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
level_type = serializers.SlugRelatedField(read_only=True, slug_field='name') level_type = serializers.SlugRelatedField(read_only=True, slug_field='name')
class Meta(MinimalCourseRunSerializer.Meta): @classmethod
def prefetch_queryset(cls):
return CourseRun.objects.all().select_related('course', 'language', 'video').prefetch_related(
'transcript_languages',
Prefetch('seats', queryset=SeatSerializer.prefetch_queryset()),
Prefetch('staff', queryset=PersonSerializer.prefetch_queryset()),
)
class Meta:
model = CourseRun model = CourseRun
fields = MinimalCourseRunSerializer.Meta.fields + ( fields = (
'course', 'short_description', 'full_description', 'start', 'end', 'enrollment_start', 'enrollment_end', 'course', 'key', 'title', 'short_description', 'full_description', 'start', 'end',
'announcement', 'image', 'video', 'seats', 'content_language', 'transcript_languages', 'instructors', 'enrollment_start', 'enrollment_end', 'announcement', 'image', 'video', 'seats',
'staff', 'pacing_type', 'min_effort', 'max_effort', 'modified', 'marketing_url', 'level_type', 'content_language', 'transcript_languages', 'instructors', 'staff',
'availability', 'pacing_type', 'min_effort', 'max_effort', 'modified', 'marketing_url', 'level_type', 'availability',
) )
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
...@@ -332,16 +365,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer): ...@@ -332,16 +365,7 @@ class ContainedCourseRunsSerializer(serializers.Serializer):
) )
class MinimalCourseSerializer(TimestampModelSerializer): class CourseSerializer(TimestampModelSerializer):
course_runs = MinimalCourseRunSerializer(many=True)
owners = MinimalOrganizationSerializer(many=True, source='authoring_organizations')
class Meta:
model = Course
fields = ('key', 'uuid', 'title', 'course_runs', 'owners',)
class CourseSerializer(MinimalCourseSerializer):
"""Serializer for the ``Course`` model.""" """Serializer for the ``Course`` model."""
level_type = serializers.SlugRelatedField(read_only=True, slug_field='name') level_type = serializers.SlugRelatedField(read_only=True, slug_field='name')
subjects = SubjectSerializer(many=True) subjects = SubjectSerializer(many=True)
...@@ -354,11 +378,23 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -354,11 +378,23 @@ class CourseSerializer(MinimalCourseSerializer):
course_runs = CourseRunSerializer(many=True) course_runs = CourseRunSerializer(many=True)
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
class Meta(MinimalCourseSerializer.Meta): @classmethod
def prefetch_queryset(cls):
return Course.objects.all().select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
class Meta:
model = Course model = Course
fields = MinimalCourseSerializer.Meta.fields + ( fields = (
'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites', 'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites',
'expected_learning_items', 'image', 'video', 'sponsors', 'modified', 'marketing_url', 'expected_learning_items', 'image', 'video', 'owners', 'sponsors', 'modified', 'course_runs',
'marketing_url',
) )
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
...@@ -390,55 +426,29 @@ class ContainedCoursesSerializer(serializers.Serializer): ...@@ -390,55 +426,29 @@ class ContainedCoursesSerializer(serializers.Serializer):
) )
class ProgramCourseSerializer(MinimalCourseSerializer): class ProgramCourseSerializer(CourseSerializer):
"""Serializer used to filter out excluded course runs in a course associated with the program""" """Serializer used to filter out excluded course runs in a course associated with the program"""
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
def get_course_runs(self, course): def get_course_runs(self, course):
program = self.context['program'] course_runs = self.context['course_runs']
course_runs = list(course.course_runs.all()) course_runs = [course_run for course_run in course_runs if course_run.course == course]
excluded_course_runs = list(program.excluded_course_runs.all())
course_runs = [course_run for course_run in course_runs if course_run not in excluded_course_runs]
if self.context.get('published_course_runs_only'): if self.context.get('published_course_runs_only'):
course_runs = [course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published] course_runs = [course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published]
return MinimalCourseRunSerializer( return CourseRunSerializer(
course_runs, course_runs,
many=True, many=True,
context={'request': self.context.get('request')} context={'request': self.context.get('request')}
).data ).data
class MinimalProgramSerializer(serializers.ModelSerializer): class ProgramSerializer(serializers.ModelSerializer):
authoring_organizations = MinimalOrganizationSerializer(many=True)
banner_image = StdImageSerializerField()
courses = serializers.SerializerMethodField() courses = serializers.SerializerMethodField()
authoring_organizations = OrganizationSerializer(many=True)
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all()) type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
banner_image = StdImageSerializerField()
def get_courses(self, program):
course_serializer = ProgramCourseSerializer(
program.courses.all(),
many=True,
context={
'request': self.context.get('request'),
'program': program,
'published_course_runs_only': self.context.get('published_course_runs_only'),
}
)
return course_serializer.data
class Meta:
model = Program
fields = (
'uuid', 'title', 'subtitle', 'type', 'status', 'marketing_slug', 'marketing_url', 'banner_image', 'courses',
'authoring_organizations', 'card_image_url',
)
read_only_fields = ('uuid', 'marketing_url', 'banner_image')
class ProgramSerializer(MinimalProgramSerializer):
video = VideoSerializer() video = VideoSerializer()
expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value') expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value')
faq = FAQSerializer(many=True) faq = FAQSerializer(many=True)
...@@ -457,14 +467,113 @@ class ProgramSerializer(MinimalProgramSerializer): ...@@ -457,14 +467,113 @@ class ProgramSerializer(MinimalProgramSerializer):
subjects = SubjectSerializer(many=True) subjects = SubjectSerializer(many=True)
staff = PersonSerializer(many=True) staff = PersonSerializer(many=True)
class Meta(MinimalProgramSerializer.Meta): @classmethod
def prefetch_queryset(cls):
"""
Prefetch the related objects that will be serialized with a `Program`.
We use Pefetch objects so that we can prefetch and select all the way down the
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
"""
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs',
'expected_learning_items',
'faq',
'job_outlook_items',
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
# `ProgramType` on `Program`.
'type__applicable_seat_types',
Prefetch('courses', queryset=ProgramCourseSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
)
def get_courses(self, program):
courses, course_runs = self.sort_courses(program)
course_serializer = ProgramCourseSerializer(
courses,
many=True,
context={
'request': self.context.get('request'),
'program': program,
'published_course_runs_only': self.context.get('published_course_runs_only'),
'course_runs': course_runs,
}
)
return course_serializer.data
def sort_courses(self, program):
"""
Sorting by enrollment start then by course start yields a list ordered by course start, with
ties broken by enrollment start. This works because Python sorting is stable: two objects with
equal keys appear in the same order in sorted output as they appear in the input.
Courses are only created if there's at least one course run belonging to that course, so
course_runs should never be empty. If it is, key functions in this method attempting to find the
min of an empty sequence will raise a ValueError.
"""
course_runs = list(program.course_runs)
def min_run_enrollment_start(course):
# Enrollment starts may be empty. When this is the case, we make the same assumption as
# the LMS: no enrollment_start is equivalent to (offset-aware) datetime.datetime.min.
min_datetime = datetime.datetime.min.replace(tzinfo=pytz.UTC)
# Course runs excluded from the program are excluded here, too.
#
# If this becomes a candidate for optimization in the future, be careful sorting null values
# in the database. PostgreSQL and MySQL sort null values as if they are higher than non-null
# values, while SQLite does the opposite.
#
# For more, refer to https://docs.djangoproject.com/en/1.10/ref/models/querysets/#latest.
_course_runs = [course_run for course_run in course_runs if course_run.course == course]
# Return early if we have no course runs since min() will fail.
if not _course_runs:
return min_datetime
run = min(_course_runs, key=lambda run: run.enrollment_start or min_datetime)
return run.enrollment_start or min_datetime
def min_run_start(course):
# Course starts may be empty. Since this means the course can't be started, missing course
# start date is equivalent to (offset-aware) datetime.datetime.max.
max_datetime = datetime.datetime.max.replace(tzinfo=pytz.UTC)
_course_runs = [course_run for course_run in course_runs if course_run.course == course]
# Return early if we have no course runs since min() will fail.
if not _course_runs:
return max_datetime
run = min(_course_runs, key=lambda run: run.start or max_datetime)
return run.start or max_datetime
courses = list(program.courses.all())
courses.sort(key=min_run_enrollment_start)
courses.sort(key=min_run_start)
return courses, course_runs
class Meta:
model = Program model = Program
fields = MinimalProgramSerializer.Meta.fields + ( fields = (
'overview', 'weeks_to_complete', 'min_hours_effort_per_week', 'max_hours_effort_per_week', 'video', 'uuid', 'title', 'subtitle', 'type', 'status', 'marketing_slug', 'marketing_url', 'courses',
'overview', 'weeks_to_complete', 'min_hours_effort_per_week', 'max_hours_effort_per_week',
'authoring_organizations', 'banner_image', 'banner_image_url', 'card_image_url', 'video',
'expected_learning_items', 'faq', 'credit_backing_organizations', 'corporate_endorsements', 'expected_learning_items', 'faq', 'credit_backing_organizations', 'corporate_endorsements',
'job_outlook_items', 'individual_endorsements', 'languages', 'transcript_languages', 'subjects', 'job_outlook_items', 'individual_endorsements', 'languages', 'transcript_languages', 'subjects',
'price_ranges', 'staff', 'credit_redemption_overview', 'price_ranges', 'staff', 'credit_redemption_overview',
) )
read_only_fields = ('uuid', 'marketing_url', 'banner_image')
class AffiliateWindowSerializer(serializers.ModelSerializer): class AffiliateWindowSerializer(serializers.ModelSerializer):
...@@ -481,7 +590,7 @@ class AffiliateWindowSerializer(serializers.ModelSerializer): ...@@ -481,7 +590,7 @@ class AffiliateWindowSerializer(serializers.ModelSerializer):
category = serializers.SerializerMethodField() category = serializers.SerializerMethodField()
price = serializers.SerializerMethodField() price = serializers.SerializerMethodField()
class Meta(object): class Meta:
model = Seat model = Seat
fields = ( fields = (
'name', 'pid', 'desc', 'category', 'purl', 'imgurl', 'price', 'currency' 'name', 'pid', 'desc', 'category', 'purl', 'imgurl', 'price', 'currency'
...@@ -509,7 +618,7 @@ class FlattenedCourseRunWithCourseSerializer(CourseRunSerializer): ...@@ -509,7 +618,7 @@ class FlattenedCourseRunWithCourseSerializer(CourseRunSerializer):
course_key = serializers.SlugRelatedField(read_only=True, source='course', slug_field='key') course_key = serializers.SlugRelatedField(read_only=True, source='course', slug_field='key')
image = ImageField(read_only=True, source='card_image_url') image = ImageField(read_only=True, source='card_image_url')
class Meta(object): class Meta:
model = CourseRun model = CourseRun
fields = ( fields = (
'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites', 'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites',
......
...@@ -9,14 +9,12 @@ from rest_framework.test import APIRequestFactory ...@@ -9,14 +9,12 @@ from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseRunSerializer, ContainedCoursesSerializer, ImageSerializer, CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer, ImageSerializer,
SubjectSerializer, PrerequisiteSerializer, VideoSerializer, OrganizationSerializer, SeatSerializer, SubjectSerializer, PrerequisiteSerializer, VideoSerializer, OrganizationSerializer, SeatSerializer,
PersonSerializer, AffiliateWindowSerializer, ContainedCourseRunsSerializer, CourseRunSearchSerializer, PersonSerializer, AffiliateWindowSerializer, ContainedCourseRunsSerializer, CourseRunSearchSerializer,
ProgramSerializer, ProgramSearchSerializer, ProgramCourseSerializer, NestedProgramSerializer, ProgramSerializer, ProgramSearchSerializer, ProgramCourseSerializer, NestedProgramSerializer,
CourseRunWithProgramsSerializer, CourseWithProgramsSerializer, CorporateEndorsementSerializer, CourseRunWithProgramsSerializer, CourseWithProgramsSerializer, CorporateEndorsementSerializer,
FAQSerializer, EndorsementSerializer, PositionSerializer, FlattenedCourseRunWithCourseSerializer, FAQSerializer, EndorsementSerializer, PositionSerializer, FlattenedCourseRunWithCourseSerializer
MinimalCourseSerializer, MinimalOrganizationSerializer, MinimalCourseRunSerializer, MinimalProgramSerializer,
CourseSerializer
) )
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
...@@ -32,7 +30,7 @@ from course_discovery.apps.course_metadata.tests.factories import ( ...@@ -32,7 +30,7 @@ from course_discovery.apps.course_metadata.tests.factories import (
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint:disable=no-member, test-inherits-tests # pylint:disable=no-member
def json_date_format(datetime_obj): def json_date_format(datetime_obj):
return datetime.strftime(datetime_obj, "%Y-%m-%dT%H:%M:%S.%fZ") return datetime.strftime(datetime_obj, "%Y-%m-%dT%H:%M:%S.%fZ")
...@@ -94,36 +92,19 @@ class CatalogSerializerTests(TestCase): ...@@ -94,36 +92,19 @@ class CatalogSerializerTests(TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(TestCase): class CourseSerializerTests(TestCase):
serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request):
context = {'request': request}
return {
'key': course.key,
'uuid': str(course.uuid),
'title': course.title,
'course_runs': MinimalCourseRunSerializer(course.course_runs, many=True, context=context).data,
'owners': MinimalOrganizationSerializer(course.authoring_organizations, many=True, context=context).data,
}
def test_data(self): def test_data(self):
request = make_request() course = CourseFactory()
organizations = OrganizationFactory() video = course.video
course = CourseFactory(authoring_organizations=[organizations])
CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request)
self.assertDictEqual(serializer.data, expected)
request = make_request()
class CourseSerializerTests(MinimalCourseSerializerTests): CourseRunFactory.create_batch(3, course=course)
serializer_class = CourseSerializer serializer = CourseWithProgramsSerializer(course, context={'request': request})
def get_expected_data(self, course, request): expected = {
expected = super().get_expected_data(course, request) 'key': course.key,
expected.update({ 'title': course.title,
'short_description': course.short_description, 'short_description': course.short_description,
'full_description': course.full_description, 'full_description': course.full_description,
'level_type': course.level_type.name, 'level_type': course.level_type.name,
...@@ -131,9 +112,11 @@ class CourseSerializerTests(MinimalCourseSerializerTests): ...@@ -131,9 +112,11 @@ class CourseSerializerTests(MinimalCourseSerializerTests):
'prerequisites': [], 'prerequisites': [],
'expected_learning_items': [], 'expected_learning_items': [],
'image': ImageField().to_representation(course.card_image_url), 'image': ImageField().to_representation(course.card_image_url),
'video': VideoSerializer(course.video).data, 'video': VideoSerializer(video).data,
'owners': OrganizationSerializer(course.authoring_organizations, many=True).data,
'sponsors': OrganizationSerializer(course.sponsoring_organizations, many=True).data, 'sponsors': OrganizationSerializer(course.sponsoring_organizations, many=True).data,
'modified': json_date_format(course.modified), # pylint: disable=no-member 'modified': json_date_format(course.modified), # pylint: disable=no-member
'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data,
'marketing_url': '{url}?{params}'.format( 'marketing_url': '{url}?{params}'.format(
url=course.marketing_url, url=course.marketing_url,
params=urlencode({ params=urlencode({
...@@ -141,47 +124,22 @@ class CourseSerializerTests(MinimalCourseSerializerTests): ...@@ -141,47 +124,22 @@ class CourseSerializerTests(MinimalCourseSerializerTests):
'utm_medium': request.user.referral_tracking_id, 'utm_medium': request.user.referral_tracking_id,
}) })
), ),
'course_runs': CourseRunSerializer(course.course_runs, many=True, context={'request': request}).data,
'owners': OrganizationSerializer(course.authoring_organizations, many=True).data,
})
return expected
class CourseWithProgramsSerializerTests(CourseSerializerTests): # pylint: disable=test-inherits-tests
serializer_class = CourseWithProgramsSerializer
def get_expected_data(self, course, request):
expected = super().get_expected_data(course, request)
expected.update({
'programs': NestedProgramSerializer(course.programs, many=True, context={'request': request}).data, 'programs': NestedProgramSerializer(course.programs, many=True, context={'request': request}).data,
}) }
return expected
class MinimalCourseRunSerializerTests(TestCase): self.assertDictEqual(serializer.data, expected)
serializer_class = MinimalCourseRunSerializer
def get_expected_data(self, course_run, request): # pylint: disable=unused-argument
return {
'key': course_run.key,
'uuid': str(course_run.uuid),
'title': course_run.title,
}
class CourseRunSerializerTests(TestCase):
def test_data(self): def test_data(self):
request = make_request() request = make_request()
course_run = CourseRunFactory() course_run = CourseRunFactory()
serializer = self.serializer_class(course_run, context={'request': request}) course = course_run.course
expected = self.get_expected_data(course_run, request) video = course_run.video
self.assertDictEqual(serializer.data, expected) serializer = CourseRunSerializer(course_run, context={'request': request})
ProgramFactory(courses=[course])
class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disable=test-inherits-tests
serializer_class = CourseRunSerializer
def get_expected_data(self, course_run, request): expected = {
expected = super().get_expected_data(course_run, request)
expected.update({
'course': course_run.course.key, 'course': course_run.course.key,
'key': course_run.key, 'key': course_run.key,
'title': course_run.title, # pylint: disable=no-member 'title': course_run.title, # pylint: disable=no-member
...@@ -193,7 +151,7 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa ...@@ -193,7 +151,7 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa
'enrollment_end': json_date_format(course_run.enrollment_end), 'enrollment_end': json_date_format(course_run.enrollment_end),
'announcement': json_date_format(course_run.announcement), 'announcement': json_date_format(course_run.announcement),
'image': ImageField().to_representation(course_run.card_image_url), 'image': ImageField().to_representation(course_run.card_image_url),
'video': VideoSerializer(course_run.video).data, 'video': VideoSerializer(video).data,
'pacing_type': course_run.pacing_type, 'pacing_type': course_run.pacing_type,
'content_language': course_run.language.code, 'content_language': course_run.language.code,
'transcript_languages': [], 'transcript_languages': [],
...@@ -212,8 +170,9 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa ...@@ -212,8 +170,9 @@ class CourseRunSerializerTests(MinimalCourseRunSerializerTests): # pylint: disa
), ),
'level_type': course_run.level_type.name, 'level_type': course_run.level_type.name,
'availability': course_run.availability, 'availability': course_run.availability,
}) }
return expected
self.assertDictEqual(serializer.data, expected)
class CourseRunWithProgramsSerializerTests(TestCase): class CourseRunWithProgramsSerializerTests(TestCase):
...@@ -325,96 +284,92 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover ...@@ -325,96 +284,92 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
class ProgramCourseSerializerTests(TestCase): class ProgramCourseSerializerTests(TestCase):
def setUp(self): def setUp(self):
super(ProgramCourseSerializerTests, self).setUp() super(ProgramCourseSerializerTests, self).setUp()
self.request = make_request() self.program = ProgramFactory(courses=[CourseFactory()])
self.course_list = CourseFactory.create_batch(3)
self.program = ProgramFactory(courses=self.course_list) def assert_program_courses_serialized(self, program):
request = make_request()
def test_no_run(self):
"""
Make sure that if a course has no runs, the serializer still works as expected
"""
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
self.course_list, program.courses,
many=True, many=True,
context={'request': self.request, 'program': self.program} context={
'request': request,
'program': program,
'course_runs': program.course_runs
}
) )
expected = CourseSerializer(program.courses, many=True, context={'request': request}).data
expected = MinimalCourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertSequenceEqual(serializer.data, expected) self.assertSequenceEqual(serializer.data, expected)
def test_with_runs(self): def test_data(self):
for course in self.course_list: for course in self.program.courses.all():
CourseRunFactory.create_batch(2, course=course) CourseRunFactory(course=course)
serializer = ProgramCourseSerializer(
self.course_list,
many=True,
context={'request': self.request, 'program': self.program}
)
expected = MinimalCourseSerializer(self.course_list, many=True, context={'request': self.request}).data self.assert_program_courses_serialized(self.program)
self.assertSequenceEqual(serializer.data, expected) def test_data_without_course_runs(self):
"""
Make sure that if a course has no runs, the serializer still works as expected
"""
self.assert_program_courses_serialized(self.program)
def test_with_exclusions(self): def test_with_exclusions(self):
""" """
Test serializer with course_run exclusions within program Test serializer with course_run exclusions within program
""" """
request = make_request()
course = CourseFactory() course = CourseFactory()
excluded_runs = [] excluded_runs = []
course_runs = CourseRunFactory.create_batch(2, course=course) course_runs = CourseRunFactory.create_batch(2, course=course)
excluded_runs.append(course_runs[0]) excluded_runs.append(course_runs[0])
program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs) program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs)
serializer_context = {'request': self.request, 'program': program} serializer_context = {'request': request, 'program': program, 'course_runs': program.course_runs}
serializer = ProgramCourseSerializer(course, context=serializer_context) serializer = ProgramCourseSerializer(course, context=serializer_context)
expected = MinimalCourseSerializer(course, context=serializer_context).data expected = CourseSerializer(course, context=serializer_context).data
expected['course_runs'] = MinimalCourseRunSerializer([course_runs[1]], many=True, expected['course_runs'] = CourseRunSerializer([course_runs[1]], many=True,
context={'request': self.request}).data context={'request': request}).data
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_with_published_course_runs_only_context(self): def test_with_published_course_runs_only_context(self):
""" Verify setting the published_course_runs_only context value excludes unpublished course runs. """ """ Verify setting the published_course_runs_only context value excludes unpublished course runs. """
# Create a program and course. The course should have both published and un-published course runs. # Create a program and course. The course should have both published and un-published course runs.
request = make_request()
course = CourseFactory() course = CourseFactory()
courses = [course] program = ProgramFactory(courses=[course])
program = ProgramFactory(courses=courses)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=course) unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=course)
CourseRunFactory(status=CourseRunStatus.Published, course=course) CourseRunFactory(status=CourseRunStatus.Published, course=course)
# We do NOT expect the results to included the unpublished data # We do NOT expect the results to included the unpublished data
expected = MinimalCourseSerializer(courses, many=True, context={'request': self.request}).data expected = CourseSerializer(course, context={'request': request}).data
expected[0]['course_runs'] = [course_run for course_run in expected[0]['course_runs'] if expected['course_runs'] = [course_run for course_run in expected['course_runs'] if
course_run['uuid'] != str(unpublished_course_run.uuid)] course_run['key'] != str(unpublished_course_run.key)]
self.assertEqual(len(expected[0]['course_runs']), 1) self.assertEqual(len(expected['course_runs']), 1)
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
courses, course,
many=True,
context={ context={
'request': self.request, 'request': request,
'program': program, 'program': program,
'published_course_runs_only': True, 'published_course_runs_only': True,
'course_runs': program.course_runs,
} }
) )
self.assertSequenceEqual(serializer.data, expected) self.assertSequenceEqual(serializer.data, expected)
class MinimalProgramSerializerTests(TestCase): class ProgramSerializerTests(TestCase):
serializer_class = MinimalProgramSerializer
def create_program(self): def create_program(self):
organizations = OrganizationFactory.create_batch(2) organizations = [OrganizationFactory()]
person = PersonFactory() person = PersonFactory()
courses = CourseFactory.create_batch(3) course = CourseFactory()
for course in courses: CourseRunFactory(course=course, staff=[person])
CourseRunFactory.create_batch(2, course=course, staff=[person])
program = ProgramFactory( program = ProgramFactory(
courses=courses, courses=[course],
authoring_organizations=organizations, authoring_organizations=organizations,
credit_backing_organizations=organizations, credit_backing_organizations=organizations,
corporate_endorsements=CorporateEndorsementFactory.create_batch(1), corporate_endorsements=CorporateEndorsementFactory.create_batch(1),
...@@ -439,28 +394,17 @@ class MinimalProgramSerializerTests(TestCase): ...@@ -439,28 +394,17 @@ class MinimalProgramSerializerTests(TestCase):
'marketing_slug': program.marketing_slug, 'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url, 'marketing_url': program.marketing_url,
'banner_image': image_field.to_representation(program.banner_image), 'banner_image': image_field.to_representation(program.banner_image),
'courses': ProgramCourseSerializer(program.courses, many=True, 'banner_image_url': program.banner_image_url,
context={'request': request, 'program': program}).data, 'courses': ProgramCourseSerializer(
'authoring_organizations': MinimalOrganizationSerializer(program.authoring_organizations, many=True).data, program.courses,
many=True,
context={
'request': request,
'program': program,
'course_runs': program.course_runs,
}).data,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'card_image_url': program.card_image_url, 'card_image_url': program.card_image_url,
}
def test_data(self):
request = make_request()
program = self.create_program()
serializer = self.serializer_class(program, context={'request': request})
expected = self.get_expected_data(program, request)
self.assertDictEqual(serializer.data, expected)
class ProgramSerializerTests(MinimalProgramSerializerTests): # pylint: disable=test-inherits-tests
serializer_class = ProgramSerializer
def get_expected_data(self, program, request):
expected = super().get_expected_data(program, request)
expected.update({
'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url,
'video': VideoSerializer(program.video).data, 'video': VideoSerializer(program.video).data,
'credit_redemption_overview': program.credit_redemption_overview, 'credit_redemption_overview': program.credit_redemption_overview,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data, 'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
...@@ -478,11 +422,17 @@ class ProgramSerializerTests(MinimalProgramSerializerTests): # pylint: disable= ...@@ -478,11 +422,17 @@ class ProgramSerializerTests(MinimalProgramSerializerTests): # pylint: disable=
'max_hours_effort_per_week': program.max_hours_effort_per_week, 'max_hours_effort_per_week': program.max_hours_effort_per_week,
'min_hours_effort_per_week': program.min_hours_effort_per_week, 'min_hours_effort_per_week': program.min_hours_effort_per_week,
'overview': program.overview, 'overview': program.overview,
'price_ranges': [], 'price_ranges': program.price_ranges,
'subjects': SubjectSerializer(program.subjects, many=True).data, 'subjects': SubjectSerializer(program.subjects, many=True).data,
'transcript_languages': [serialize_language_to_code(l) for l in program.transcript_languages], 'transcript_languages': [serialize_language_to_code(l) for l in program.transcript_languages],
}) }
return expected
def test_data(self):
request = make_request()
program = self.create_program()
serializer = ProgramSerializer(program, context={'request': request})
expected = self.get_expected_data(program, request)
self.assertDictEqual(dict(serializer.data), expected)
def test_data_with_exclusions(self): def test_data_with_exclusions(self):
""" """
...@@ -622,45 +572,24 @@ class VideoSerializerTests(TestCase): ...@@ -622,45 +572,24 @@ class VideoSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
class MinimalOrganizationSerializerTests(TestCase): class OrganizationSerializerTests(TestCase):
serializer_class = MinimalOrganizationSerializer def test_data(self):
organization = OrganizationFactory()
def create_organization(self): TAG = 'test'
return OrganizationFactory() organization.tags.add(TAG)
serializer = OrganizationSerializer(organization)
def get_expected_data(self, organization): expected = {
return {
'uuid': str(organization.uuid),
'key': organization.key, 'key': organization.key,
'name': organization.name, 'name': organization.name,
}
def test_data(self):
organization = self.create_organization()
serializer = self.serializer_class(organization)
expected = self.get_expected_data(organization)
self.assertDictEqual(serializer.data, expected)
class OrganizationSerializerTests(MinimalOrganizationSerializerTests):
TAG = 'test-tag'
serializer_class = OrganizationSerializer
def create_organization(self):
organization = super().create_organization()
organization.tags.add(self.TAG)
return organization
def get_expected_data(self, organization):
expected = super().get_expected_data(organization)
expected.update({
'description': organization.description, 'description': organization.description,
'homepage_url': organization.homepage_url, 'homepage_url': organization.homepage_url,
'logo_image_url': organization.logo_image_url, 'logo_image_url': organization.logo_image_url,
'tags': [self.TAG], 'tags': [TAG],
'marketing_url': organization.marketing_url, 'marketing_url': organization.marketing_url,
}) }
return expected
self.assertDictEqual(serializer.data, expected)
class SeatSerializerTests(TestCase): class SeatSerializerTests(TestCase):
...@@ -806,7 +735,7 @@ class ProgramSearchSerializerTests(TestCase): ...@@ -806,7 +735,7 @@ class ProgramSearchSerializerTests(TestCase):
'partner': program.partner.short_code, 'partner': program.partner.short_code,
'authoring_organization_uuids': get_uuids(program.authoring_organizations.all()), 'authoring_organization_uuids': get_uuids(program.authoring_organizations.all()),
'subject_uuids': get_uuids([course.subjects for course in program.courses.all()]), 'subject_uuids': get_uuids([course.subjects for course in program.courses.all()]),
'staff_uuids': get_uuids([course.staff for course in program.course_runs.all()]) 'staff_uuids': get_uuids([course.staff for course in program.course_runs])
} }
def test_data(self): def test_data(self):
......
...@@ -40,14 +40,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -40,14 +40,14 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = ProgramFactory()
with self.assertNumQueries(34): with self.assertNumQueries(33):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory() course = CourseFactory()
program = ProgramFactory(courses=[course]) program = ProgramFactory(courses=[course])
with self.assertNumQueries(49): with self.assertNumQueries(55):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
def assert_list_results(self, url, expected, expected_query_count): def assert_list_results(self, url, expected, expected_query_count):
...@@ -76,14 +76,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -76,14 +76,14 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = ProgramFactory.create_batch(3)
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 40) self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name) program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 18) self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar' url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 4) self.assert_list_results(url, [], 4)
...@@ -98,11 +98,11 @@ class ProgramViewSetTests(APITestCase): ...@@ -98,11 +98,11 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
self.assert_list_results(url, expected, 29) self.assert_list_results(url, expected, 14)
@ddt.data( @ddt.data(
(ProgramStatus.Unpublished, False, 4), (ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True, 40), (ProgramStatus.Active, True, 14),
) )
@ddt.unpack @ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count): def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
......
...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset return queryset
def prefetch_related_objects_for_programs(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Program`.
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
course = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
course = ['courses__' + field for field in course]
queryset = queryset.prefetch_related(*course)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['program'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['program'])
return queryset
# pylint: disable=no-member # pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet): class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """ """ Catalog resource. """
...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """ """ Program resource. """
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter filter_class = filters.ProgramFilter
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0)) context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime import datetime
import itertools import itertools
import logging import logging
from collections import defaultdict
from urllib.parse import urljoin from urllib.parse import urljoin
from uuid import uuid4 from uuid import uuid4
...@@ -635,57 +636,66 @@ class Program(TimeStampedModel): ...@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property @property
def course_runs(self): def course_runs(self):
"""
Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
Otherwise, this method will incur many, many queries when fetching related courses and course runs.
"""
excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()] excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()]
return CourseRun.objects.filter(course__programs=self).exclude(id__in=excluded_course_run_ids)
for course in self.courses.all():
for run in course.course_runs.all():
if run.id not in excluded_course_run_ids:
yield run
@property @property
def languages(self): def languages(self):
course_runs = self.course_runs.select_related('language') return set(course_run.language for course_run in self.course_runs if course_run.language is not None)
return set(course_run.language for course_run in course_runs if course_run.language is not None)
@property @property
def transcript_languages(self): def transcript_languages(self):
course_runs = self.course_runs.prefetch_related('transcript_languages') languages = [course_run.transcript_languages.all() for course_run in self.course_runs]
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = itertools.chain.from_iterable(languages) languages = itertools.chain.from_iterable(languages)
return set(languages) return set(languages)
@property @property
def subjects(self): def subjects(self):
courses = self.courses.prefetch_related('subjects') subjects = [course.subjects.all() for course in self.courses.all()]
subjects = [list(course.subjects.all()) for course in courses]
subjects = itertools.chain.from_iterable(subjects) subjects = itertools.chain.from_iterable(subjects)
return set(subjects) return set(subjects)
@property @property
def seats(self): def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True) applicable_seat_types = set(seat_type.slug for seat_type in self.type.applicable_seat_types.all())
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency') for run in self.course_runs:
for seat in run.seats.all():
if seat.type in applicable_seat_types:
yield seat
@property @property
def seat_types(self): def seat_types(self):
return set(self.seats.values_list('type', flat=True)) return set(seat.type for seat in self.seats)
@property @property
def price_ranges(self): def price_ranges(self):
seats = self.seats.values('currency').annotate(models.Min('price'), models.Max('price')) currencies = defaultdict(list)
price_ranges = [] for seat in self.seats:
currencies[seat.currency].append(seat.price)
for seat in seats: price_ranges = []
for currency, prices in currencies.items():
price_ranges.append({ price_ranges.append({
'currency': seat['currency'], 'currency': currency.code,
'min': seat['price__min'], 'min': min(prices),
'max': seat['price__max'], 'max': max(prices),
}) })
return price_ranges return price_ranges
@property @property
def start(self): def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """ """ Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs if self.course_runs:
if course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start] start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates: if start_dates:
...@@ -695,7 +705,7 @@ class Program(TimeStampedModel): ...@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property @property
def staff(self): def staff(self):
staff = [list(course_run.staff.all()) for course_run in self.course_runs] staff = [course_run.staff.all() for course_run in self.course_runs]
staff = itertools.chain.from_iterable(staff) staff = itertools.chain.from_iterable(staff)
return set(staff) return set(staff)
......
...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin): ...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()] return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj): def prepare_staff_uuids(self, obj):
return [str(staff.uuid) for course_run in obj.course_runs.all() for staff in course_run.staff.all()] return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj): def prepare_credit_backing_organizations(self, obj):
return self._prepare_organizations(obj.credit_backing_organizations.all()) return self._prepare_organizations(obj.credit_backing_organizations.all())
......
...@@ -100,7 +100,7 @@ class AdminTests(TestCase): ...@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """ """ Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count()) self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all())) self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = { params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id], 'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
...@@ -114,7 +114,7 @@ class AdminTests(TestCase): ...@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(2, self.program.excluded_course_runs.all().count()) self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all())) self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self): def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """ """ Verify that course selection page without posting any selected excluded check run. """
...@@ -132,7 +132,7 @@ class AdminTests(TestCase): ...@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(0, self.program.excluded_course_runs.all().count()) self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all())) self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,))) response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")') self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment