Commit bc3d7d01 by Renzo Lucioni

Program query optimizations

parent 3dd051a1
......@@ -5,6 +5,7 @@ from urllib.parse import urlencode
import pytz
from django.contrib.auth import get_user_model
from django.db.models.query import Prefetch
from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
from rest_framework import serializers
......@@ -69,25 +70,39 @@ PROGRAM_FACET_FIELDS = BASE_PROGRAM_FIELDS + ('organizations',)
PREFETCH_FIELDS = {
'course_run': [
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff',
'staff__position', 'staff__position__organization', 'language',
'course__level_type',
'course__partner',
'course__programs',
'course__programs__partner',
'course__programs__type',
'language',
'seats',
'seats__currency',
'staff',
'staff__position',
'staff__position__organization',
'transcript_languages',
],
'course': [
'level_type', 'video', 'programs', 'course_runs', 'subjects', 'prerequisites', 'expected_learning_items',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'sponsoring_organizations', 'sponsoring_organizations__tags', 'sponsoring_organizations__partner',
],
'program': [
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'excluded_course_runs', 'courses', 'courses__authoring_organizations', 'courses__course_runs',
'authoring_organizations',
'authoring_organizations__partner',
'authoring_organizations__tags',
'course_runs',
'expected_learning_items',
'level_type',
'prerequisites',
'programs',
'sponsoring_organizations',
'sponsoring_organizations__partner',
'sponsoring_organizations__tags',
'subjects',
'video',
],
}
SELECT_RELATED_FIELDS = {
'course': ['level_type', 'video', ],
'course_run': ['course', 'language', 'video', ],
'program': ['type', 'video', 'partner', ],
'course': ['level_type', 'partner', 'video'],
'course_run': ['course', 'language', 'video'],
}
......@@ -184,6 +199,10 @@ class PersonSerializer(serializers.ModelSerializer):
"""Serializer for the ``Person`` model."""
position = PositionSerializer()
@classmethod
def prefetch_queryset(cls):
return Person.objects.all().select_related('position__organization')
class Meta(object):
model = Person
fields = ('uuid', 'given_name', 'family_name', 'bio', 'profile_image_url', 'slug', 'position')
......@@ -193,6 +212,10 @@ class EndorsementSerializer(serializers.ModelSerializer):
"""Serializer for the ``Endorsement`` model."""
endorser = PersonSerializer()
@classmethod
def prefetch_queryset(cls):
return Endorsement.objects.all().select_related('endorser')
class Meta(object):
model = Endorsement
fields = ('endorser', 'quote',)
......@@ -203,6 +226,12 @@ class CorporateEndorsementSerializer(serializers.ModelSerializer):
image = ImageSerializer()
individual_endorsements = EndorsementSerializer(many=True)
@classmethod
def prefetch_queryset(cls):
return CorporateEndorsement.objects.all().select_related('image').prefetch_related(
Prefetch('endorser', queryset=EndorsementSerializer.prefetch_queryset()),
)
class Meta(object):
model = CorporateEndorsement
fields = ('corporation_name', 'statement', 'image', 'individual_endorsements',)
......@@ -222,6 +251,10 @@ class SeatSerializer(serializers.ModelSerializer):
credit_provider = serializers.CharField()
credit_hours = serializers.IntegerField()
@classmethod
def prefetch_queryset(cls):
return Seat.objects.all().select_related('currency')
class Meta(object):
model = Seat
fields = ('type', 'price', 'currency', 'upgrade_deadline', 'credit_provider', 'credit_hours',)
......@@ -231,6 +264,10 @@ class OrganizationSerializer(TaggitSerializer, serializers.ModelSerializer):
"""Serializer for the ``Organization`` model."""
tags = TagListSerializerField()
@classmethod
def prefetch_queryset(cls):
return Organization.objects.all().select_related('partner').prefetch_related('tags')
class Meta(object):
model = Organization
fields = ('key', 'name', 'description', 'homepage_url', 'tags', 'logo_image_url', 'marketing_url')
......@@ -287,7 +324,15 @@ class CourseRunSerializer(TimestampModelSerializer):
marketing_url = serializers.SerializerMethodField()
level_type = serializers.SlugRelatedField(read_only=True, slug_field='name')
class Meta(object):
@classmethod
def prefetch_queryset(cls):
return CourseRun.objects.all().select_related('course', 'language', 'video').prefetch_related(
'transcript_languages',
Prefetch('seats', queryset=SeatSerializer.prefetch_queryset()),
Prefetch('staff', queryset=PersonSerializer.prefetch_queryset()),
)
class Meta:
model = CourseRun
fields = (
'course', 'key', 'title', 'short_description', 'full_description', 'start', 'end',
......@@ -333,7 +378,18 @@ class CourseSerializer(TimestampModelSerializer):
course_runs = CourseRunSerializer(many=True)
marketing_url = serializers.SerializerMethodField()
class Meta(object):
@classmethod
def prefetch_queryset(cls):
return Course.objects.all().select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
class Meta:
model = Course
fields = (
'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites',
......@@ -411,6 +467,31 @@ class ProgramSerializer(serializers.ModelSerializer):
subjects = SubjectSerializer(many=True)
staff = PersonSerializer(many=True)
@classmethod
def prefetch_queryset(cls):
"""
Prefetch the related objects that will be serialized with a `Program`.
We use Pefetch objects so that we can prefetch and select all the way down the
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
"""
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs',
'expected_learning_items',
'faq',
'job_outlook_items',
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
# `ProgramType` on `Program`.
'type__applicable_seat_types',
Prefetch('courses', queryset=ProgramCourseSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
)
def get_courses(self, program):
courses, course_runs = self.sort_courses(program)
......@@ -437,9 +518,7 @@ class ProgramSerializer(serializers.ModelSerializer):
course_runs should never be empty. If it is, key functions in this method attempting to find the
min of an empty sequence will raise a ValueError.
"""
course_runs = program.course_runs.select_related(*SELECT_RELATED_FIELDS['course_run'])
course_runs = course_runs.prefetch_related(*PREFETCH_FIELDS['course_run'])
course_runs = list(course_runs)
course_runs = list(program.course_runs)
def min_run_enrollment_start(course):
# Enrollment starts may be empty. When this is the case, we make the same assumption as
......@@ -492,7 +571,7 @@ class ProgramSerializer(serializers.ModelSerializer):
'authoring_organizations', 'banner_image', 'banner_image_url', 'card_image_url', 'video',
'expected_learning_items', 'faq', 'credit_backing_organizations', 'corporate_endorsements',
'job_outlook_items', 'individual_endorsements', 'languages', 'transcript_languages', 'subjects',
'price_ranges', 'staff', 'credit_redemption_overview'
'price_ranges', 'staff', 'credit_redemption_overview',
)
read_only_fields = ('uuid', 'marketing_url', 'banner_image')
......@@ -511,7 +590,7 @@ class AffiliateWindowSerializer(serializers.ModelSerializer):
category = serializers.SerializerMethodField()
price = serializers.SerializerMethodField()
class Meta(object):
class Meta:
model = Seat
fields = (
'name', 'pid', 'desc', 'category', 'purl', 'imgurl', 'price', 'currency'
......@@ -539,7 +618,7 @@ class FlattenedCourseRunWithCourseSerializer(CourseRunSerializer):
course_key = serializers.SlugRelatedField(read_only=True, source='course', slug_field='key')
image = ImageField(read_only=True, source='card_image_url')
class Meta(object):
class Meta:
model = CourseRun
fields = (
'key', 'title', 'short_description', 'full_description', 'level_type', 'subjects', 'prerequisites',
......
from django.test import TestCase
from course_discovery.apps.api.fields import ImageField
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.tests.test_serializers import make_request
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase):
......@@ -13,3 +16,27 @@ class ImageFieldTests(TestCase):
'width': None
}
self.assertEqual(ImageField().to_representation(value), expected)
# pylint: disable=no-member
class StdImageSerializerFieldTests(TestCase):
def test_to_representation(self):
request = make_request()
# TODO Create test-only model to avoid unnecessary dependency on Program model.
program = ProgramFactory(banner_image=make_image_file('test.jpg'))
field = StdImageSerializerField()
field._context = {'request': request} # pylint: disable=protected-access
expected = {
size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
self.assertDictEqual(field.to_representation(program.banner_image), expected)
import unittest
from datetime import datetime
from urllib.parse import urlencode
......@@ -8,7 +7,7 @@ from haystack.query import SearchQuerySet
from opaque_keys.edx.keys import CourseKey
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.fields import ImageField
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer, ImageSerializer,
SubjectSerializer, PrerequisiteSerializer, VideoSerializer, OrganizationSerializer, SeatSerializer,
......@@ -192,7 +191,6 @@ class CourseRunWithProgramsSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected)
@unittest.skip('This test is disabled until we can determine why assertDictEqual() fails for two equivalent inputs.')
class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
def serialize_seats(self, course_run):
seats = {
......@@ -242,7 +240,7 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
def get_expected_data(self, request, course_run):
course = course_run.course
serializer_context = {'request': request}
expected = CourseRunSerializer(course_run, context=serializer_context).data
expected = dict(CourseRunSerializer(course_run, context=serializer_context).data)
expected.update({
'subjects': self.serialize_items(course.subjects.all(), 'name'),
'seats': self.serialize_seats(course_run),
......@@ -283,371 +281,173 @@ class FlattenedCourseRunWithCourseSerializerTests(TestCase): # pragma: no cover
self.assertDictEqual(serializer.data, expected)
@ddt.ddt
class ProgramCourseSerializerTests(TestCase):
def setUp(self):
super(ProgramCourseSerializerTests, self).setUp()
self.request = make_request()
self.course_list = CourseFactory.create_batch(3)
self.program = ProgramFactory(courses=self.course_list)
self.program = ProgramFactory(courses=[CourseFactory()])
def assert_program_courses_serialized(self, program):
request = make_request()
def test_no_run(self):
"""
Make sure that if a course has no runs, the serializer still works as expected
"""
serializer = ProgramCourseSerializer(
self.course_list,
program.courses,
many=True,
context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
context={
'request': request,
'program': program,
'course_runs': program.course_runs
}
)
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
expected = CourseSerializer(program.courses, many=True, context={'request': request}).data
self.assertSequenceEqual(serializer.data, expected)
def test_with_runs(self):
for course in self.course_list:
CourseRunFactory.create_batch(2, course=course)
serializer = ProgramCourseSerializer(
self.course_list,
many=True,
context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
)
def test_data(self):
for course in self.program.courses.all():
CourseRunFactory(course=course)
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assert_program_courses_serialized(self.program)
self.assertSequenceEqual(serializer.data, expected)
def test_data_without_course_runs(self):
"""
Make sure that if a course has no runs, the serializer still works as expected
"""
self.assert_program_courses_serialized(self.program)
def test_with_exclusions(self):
"""
Test serializer with course_run exclusions within program
"""
request = make_request()
course = CourseFactory()
excluded_runs = []
course_runs = CourseRunFactory.create_batch(2, course=course)
excluded_runs.append(course_runs[0])
program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs)
serializer_context = {'request': self.request, 'program': program, 'course_runs': program.course_runs}
serializer_context = {'request': request, 'program': program, 'course_runs': program.course_runs}
serializer = ProgramCourseSerializer(course, context=serializer_context)
expected = CourseSerializer(course, context=serializer_context).data
expected['course_runs'] = CourseRunSerializer([course_runs[1]], many=True,
context={'request': self.request}).data
context={'request': request}).data
self.assertDictEqual(serializer.data, expected)
@ddt.data(
[CourseRunStatus.Unpublished, 1],
[CourseRunStatus.Unpublished, 0],
[CourseRunStatus.Published, 1],
[CourseRunStatus.Published, 0]
)
@ddt.unpack
def test_with_published_only_querystring(self, course_run_status, published_course_runs_only):
"""
Test the serializer's ability to filter out course_runs based on
"published_course_runs_only" query string
"""
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
def test_with_published_course_runs_only_context(self):
""" Verify setting the published_course_runs_only context value excludes unpublished course runs. """
# Create a program and course. The course should have both published and un-published course runs.
request = make_request()
course = CourseFactory()
program = ProgramFactory(courses=[course])
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=course)
CourseRunFactory(status=CourseRunStatus.Published, course=course)
# We do NOT expect the results to included the unpublished data
expected = CourseSerializer(course, context={'request': request}).data
expected['course_runs'] = [course_run for course_run in expected['course_runs'] if
course_run['key'] != str(unpublished_course_run.key)]
self.assertEqual(len(expected['course_runs']), 1)
for course in self.course_list:
CourseRunFactory.create_batch(2, status=course_run_status, course=course)
serializer = ProgramCourseSerializer(
self.course_list,
many=True,
course,
context={
'request': self.request,
'program': self.program,
'published_course_runs_only': published_course_runs_only,
'course_runs': self.program.course_runs
'request': request,
'program': program,
'published_course_runs_only': True,
'course_runs': program.course_runs,
}
)
validate_data = serializer.data
if not published_course_runs_only or course_run_status != CourseRunStatus.Unpublished:
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
self.assertSequenceEqual(validate_data, expected)
self.assertSequenceEqual(serializer.data, expected)
class ProgramSerializerTests(TestCase):
def test_data(self):
request = make_request()
org_list = OrganizationFactory.create_batch(1)
course_list = CourseFactory.create_batch(3)
for course in course_list:
CourseRunFactory.create_batch(
3,
course=course,
enrollment_start=datetime(2014, 1, 1),
start=datetime(2014, 1, 1)
)
corporate_endorsements = CorporateEndorsementFactory.create_batch(1)
individual_endorsements = EndorsementFactory.create_batch(1)
staff = PersonFactory.create_batch(1)
job_outlook_items = JobOutlookItemFactory.create_batch(1)
expected_learning_items = ExpectedLearningItemFactory.create_batch(1)
program = ProgramFactory(
authoring_organizations=org_list,
courses=course_list,
credit_backing_organizations=org_list,
corporate_endorsements=corporate_endorsements,
individual_endorsements=individual_endorsements,
expected_learning_items=expected_learning_items,
staff=staff,
job_outlook_items=job_outlook_items,
)
program.banner_image = make_image_file('test_banner.jpg')
program.save()
serializer = ProgramSerializer(program, context={'request': request})
expected_banner_image_urls = {
size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
expected = {
'uuid': str(program.uuid),
'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url,
'card_image_url': program.card_image_url,
'banner_image_url': program.banner_image_url,
'video': None,
'banner_image': expected_banner_image_urls,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'credit_redemption_overview': program.credit_redemption_overview,
'courses': ProgramCourseSerializer(
program.courses,
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer(
program.credit_backing_organizations,
many=True
).data,
'expected_learning_items': [item.value for item in program.expected_learning_items.all()],
'faq': FAQSerializer(program.faq, many=True).data,
'individual_endorsements': EndorsementSerializer(program.individual_endorsements, many=True).data,
'staff': PersonSerializer(program.staff, many=True).data,
'job_outlook_items': [item.value for item in program.job_outlook_items.all()],
'languages': [serialize_language_to_code(l) for l in program.languages],
'weeks_to_complete': program.weeks_to_complete,
'max_hours_effort_per_week': None,
'min_hours_effort_per_week': None,
'overview': None,
'price_ranges': [],
'status': program.status,
'subjects': [],
'transcript_languages': [],
}
def create_program(self):
organizations = [OrganizationFactory()]
person = PersonFactory()
self.assertDictEqual(serializer.data, expected)
def test_with_exclusions(self):
"""
Verify we can specify program excluded_course_runs and the serializers will
render the course_runs with exclusions
"""
request = make_request()
org_list = OrganizationFactory.create_batch(1)
course_list = CourseFactory.create_batch(4)
excluded_runs = []
for course in course_list:
course_runs = CourseRunFactory.create_batch(
3,
course=course,
enrollment_start=datetime(2014, 1, 1),
start=datetime(2014, 1, 1)
)
excluded_runs.append(course_runs[0])
course = CourseFactory()
CourseRunFactory(course=course, staff=[person])
program = ProgramFactory(
authoring_organizations=org_list,
courses=course_list,
excluded_course_runs=excluded_runs
courses=[course],
authoring_organizations=organizations,
credit_backing_organizations=organizations,
corporate_endorsements=CorporateEndorsementFactory.create_batch(1),
individual_endorsements=EndorsementFactory.create_batch(1),
expected_learning_items=ExpectedLearningItemFactory.create_batch(1),
job_outlook_items=JobOutlookItemFactory.create_batch(1),
banner_image=make_image_file('test_banner.jpg'),
video=VideoFactory()
)
serializer = ProgramSerializer(program, context={'request': request})
return program
expected = {
def get_expected_data(self, program, request):
image_field = StdImageSerializerField()
image_field._context = {'request': request} # pylint: disable=protected-access
return {
'uuid': str(program.uuid),
'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'status': program.status,
'marketing_slug': program.marketing_slug,
'marketing_url': program.marketing_url,
'card_image_url': program.card_image_url,
'banner_image': {},
'banner_image': image_field.to_representation(program.banner_image),
'banner_image_url': program.banner_image_url,
'video': None,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'credit_redemption_overview': program.credit_redemption_overview,
'courses': ProgramCourseSerializer(
program.courses,
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data,
context={
'request': request,
'program': program,
'course_runs': program.course_runs,
}).data,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'card_image_url': program.card_image_url,
'video': VideoSerializer(program.video).data,
'credit_redemption_overview': program.credit_redemption_overview,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer(
program.credit_backing_organizations,
many=True
).data,
'expected_learning_items': [],
'expected_learning_items': [item.value for item in program.expected_learning_items.all()],
'faq': FAQSerializer(program.faq, many=True).data,
'individual_endorsements': EndorsementSerializer(program.individual_endorsements, many=True).data,
'staff': PersonSerializer(program.staff, many=True).data,
'job_outlook_items': [],
'job_outlook_items': [item.value for item in program.job_outlook_items.all()],
'languages': [serialize_language_to_code(l) for l in program.languages],
'weeks_to_complete': program.weeks_to_complete,
'max_hours_effort_per_week': None,
'min_hours_effort_per_week': None,
'overview': None,
'price_ranges': [],
'status': program.status,
'subjects': [],
'transcript_languages': [],
'max_hours_effort_per_week': program.max_hours_effort_per_week,
'min_hours_effort_per_week': program.min_hours_effort_per_week,
'overview': program.overview,
'price_ranges': program.price_ranges,
'subjects': SubjectSerializer(program.subjects, many=True).data,
'transcript_languages': [serialize_language_to_code(l) for l in program.transcript_languages],
}
self.assertDictEqual(serializer.data, expected)
def test_course_ordering(self):
"""
Verify that courses in a program are ordered by ascending run start date,
with ties broken by earliest run enrollment start date.
"""
request = make_request()
course_list = CourseFactory.create_batch(3)
# Create a course run with arbitrary start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
# Create a second run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
# Create a third run with later start and enrollment_start.
CourseRunFactory(
course=course_list[0],
enrollment_start=datetime(2014, 2, 1),
start=datetime(2014, 3, 1),
)
program = ProgramFactory(courses=course_list)
serializer = ProgramSerializer(program, context={'request': request})
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
def test_course_ordering_with_exclusions(self):
"""
Verify that excluded course runs aren't used when ordering courses.
"""
def test_data(self):
request = make_request()
course_list = CourseFactory.create_batch(3)
# Create a course run with arbitrary start and empty enrollment_start.
# This run will be excluded from the program. If it wasn't excluded,
# the expected course ordering, by index, would be: 0, 2, 1.
excluded_run = CourseRunFactory(
course=course_list[0],
enrollment_start=None,
start=datetime(2014, 1, 1),
)
# Create a run with later start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
# Create a run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
# Create a run with later start and enrollment_start.
CourseRunFactory(
course=course_list[0],
enrollment_start=datetime(2014, 2, 1),
start=datetime(2014, 3, 1),
)
program = ProgramFactory(courses=course_list, excluded_course_runs=[excluded_run])
program = self.create_program()
serializer = ProgramSerializer(program, context={'request': request})
expected = self.get_expected_data(program, request)
self.assertDictEqual(dict(serializer.data), expected)
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
def test_course_ordering_with_no_start(self):
def test_data_with_exclusions(self):
"""
Verify that a courses run with missing start date appears last when ordering courses.
Verify we can specify program excluded_course_runs and the serializers will
render the course_runs with exclusions
"""
request = make_request()
course_list = CourseFactory.create_batch(3)
program = self.create_program()
# Create a course run with arbitrary start and empty enrollment_start.
CourseRunFactory(
course=course_list[2],
enrollment_start=None,
start=datetime(2014, 2, 1),
)
excluded_course_run = program.courses.all()[0].course_runs.all()[0]
program.excluded_course_runs.add(excluded_course_run)
# Create a second run with matching start, but later enrollment_start.
CourseRunFactory(
course=course_list[1],
enrollment_start=datetime(2014, 1, 2),
start=datetime(2014, 2, 1),
)
# Create a third run with empty start and enrollment_start.
CourseRunFactory(
course=course_list[0],
enrollment_start=None,
start=None,
)
program = ProgramFactory(courses=course_list)
expected = self.get_expected_data(program, request)
serializer = ProgramSerializer(program, context={'request': request})
expected = ProgramCourseSerializer(
# The expected ordering is the reverse of course_list.
course_list[::-1],
many=True,
context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data
self.assertEqual(serializer.data['courses'], expected)
self.assertDictEqual(serializer.data, expected)
class ContainedCourseRunsSerializerTests(TestCase):
......@@ -935,7 +735,7 @@ class ProgramSearchSerializerTests(TestCase):
'partner': program.partner.short_code,
'authoring_organization_uuids': get_uuids(program.authoring_organizations.all()),
'subject_uuids': get_uuids([course.subjects for course in program.courses.all()]),
'staff_uuids': get_uuids([course.staff for course in program.course_runs.all()])
'staff_uuids': get_uuids([course.staff for course in program.course_runs])
}
def test_data(self):
......
......@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2')
CourseRunFactory(enrollment_end=enrollment_end, course=self.course)
with self.assertNumQueries(41):
with self.assertNumQueries(40):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True))
......
......@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(19):
with self.assertNumQueries(18):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(25):
with self.assertNumQueries(24):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(38):
with self.assertNumQueries(35):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """
program = ProgramFactory()
self.assert_retrieve_success(program)
with self.assertNumQueries(33):
self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory()
program = ProgramFactory(courses=[course])
self.assert_retrieve_success(program)
with self.assertNumQueries(55):
self.assert_retrieve_success(program)
def assert_list_results(self, url, expected):
def assert_list_results(self, url, expected, expected_query_count):
"""
Asserts the results serialized/returned at the URL matches those that are expected.
Args:
......@@ -62,7 +64,9 @@ class ProgramViewSetTests(APITestCase):
Returns:
None
"""
response = self.client.get(url)
with self.assertNumQueries(expected_query_count):
response = self.client.get(url)
self.assertEqual(
response.data['results'],
ProgramSerializer(expected, many=True, context={'request': self.request}).data
......@@ -72,17 +76,17 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3)
expected.reverse()
self.assert_list_results(self.list_path, expected)
self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program])
self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [])
self.assert_list_results(url, [], 4)
def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
......@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out.
ProgramFactory()
self.assert_list_results(url, expected)
self.assert_list_results(url, expected, 14)
@ddt.data(
(ProgramStatus.Unpublished, False),
(ProgramStatus.Active, True),
(ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True, 14),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable):
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='')
......@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase):
expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected)
self.assert_list_results(url, expected, expected_query_count)
......@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset
def prefetch_related_objects_for_programs(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Program`.
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
course = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
course = ['courses__' + field for field in course]
queryset = queryset.prefetch_related(*course)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['program'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['program'])
return queryset
# pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """
......@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime
import itertools
import logging
from collections import defaultdict
from urllib.parse import urljoin
from uuid import uuid4
......@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property
def course_runs(self):
"""
Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
Otherwise, this method will incur many, many queries when fetching related courses and course runs.
"""
excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()]
return CourseRun.objects.filter(course__programs=self).exclude(id__in=excluded_course_run_ids)
for course in self.courses.all():
for run in course.course_runs.all():
if run.id not in excluded_course_run_ids:
yield run
@property
def languages(self):
course_runs = self.course_runs.select_related('language')
return set(course_run.language for course_run in course_runs if course_run.language is not None)
return set(course_run.language for course_run in self.course_runs if course_run.language is not None)
@property
def transcript_languages(self):
course_runs = self.course_runs.prefetch_related('transcript_languages')
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = [course_run.transcript_languages.all() for course_run in self.course_runs]
languages = itertools.chain.from_iterable(languages)
return set(languages)
@property
def subjects(self):
courses = self.courses.prefetch_related('subjects')
subjects = [list(course.subjects.all()) for course in courses]
subjects = [course.subjects.all() for course in self.courses.all()]
subjects = itertools.chain.from_iterable(subjects)
return set(subjects)
@property
def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True)
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency')
applicable_seat_types = set(seat_type.slug for seat_type in self.type.applicable_seat_types.all())
for run in self.course_runs:
for seat in run.seats.all():
if seat.type in applicable_seat_types:
yield seat
@property
def seat_types(self):
return set(self.seats.values_list('type', flat=True))
return set(seat.type for seat in self.seats)
@property
def price_ranges(self):
seats = self.seats.values('currency').annotate(models.Min('price'), models.Max('price'))
price_ranges = []
currencies = defaultdict(list)
for seat in self.seats:
currencies[seat.currency].append(seat.price)
for seat in seats:
price_ranges = []
for currency, prices in currencies.items():
price_ranges.append({
'currency': seat['currency'],
'min': seat['price__min'],
'max': seat['price__max'],
'currency': currency.code,
'min': min(prices),
'max': max(prices),
})
return price_ranges
@property
def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs
if course_runs:
if self.course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates:
......@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property
def staff(self):
staff = [list(course_run.staff.all()) for course_run in self.course_runs]
staff = [course_run.staff.all() for course_run in self.course_runs]
staff = itertools.chain.from_iterable(staff)
return set(staff)
......
......@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj):
return [str(staff.uuid) for course_run in obj.course_runs.all() for staff in course_run.staff.all()]
return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj):
return self._prepare_organizations(obj.credit_backing_organizations.all())
......
......@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all()))
self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
......@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all()))
self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """
......@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all()))
self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment