Commit 67b1c404 by Clinton Blackburn Committed by GitHub

Improved performance of the programs endpoint (#333)

Pre-fetching greatly reduces the number of queries. Locally, this has resulted in a 78% decrease in queries.

ECOM-5559
parent 4eaf27da
# pylint: disable=abstract-method # pylint: disable=abstract-method
from datetime import datetime import datetime
import json import json
from urllib.parse import urlencode from urllib.parse import urlencode
import pytz
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
import pytz
from rest_framework import serializers from rest_framework import serializers
from rest_framework.fields import DictField from rest_framework.fields import DictField
from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer from taggit_serializer.serializers import TagListSerializerField, TaggitSerializer
...@@ -67,6 +67,29 @@ PROGRAM_SEARCH_FIELDS = BASE_PROGRAM_FIELDS + ('authoring_organizations', 'autho ...@@ -67,6 +67,29 @@ PROGRAM_SEARCH_FIELDS = BASE_PROGRAM_FIELDS + ('authoring_organizations', 'autho
'subject_uuids', 'staff_uuids',) 'subject_uuids', 'staff_uuids',)
PROGRAM_FACET_FIELDS = BASE_PROGRAM_FIELDS + ('organizations',) PROGRAM_FACET_FIELDS = BASE_PROGRAM_FIELDS + ('organizations',)
PREFETCH_FIELDS = {
'course_run': [
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff',
'staff__position', 'staff__position__organization', 'language',
],
'course': [
'level_type', 'video', 'programs', 'course_runs', 'subjects', 'prerequisites', 'expected_learning_items',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'sponsoring_organizations', 'sponsoring_organizations__tags', 'sponsoring_organizations__partner',
],
'program': [
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'excluded_course_runs', 'courses', 'courses__authoring_organizations', 'courses__course_runs',
],
}
SELECT_RELATED_FIELDS = {
'course': ['level_type', 'video', ],
'course_run': ['course', 'language', 'video', ],
'program': ['type', 'video', 'partner', ],
}
def get_marketing_url_for_user(user, marketing_url): def get_marketing_url_for_user(user, marketing_url):
""" """
...@@ -352,10 +375,12 @@ class ProgramCourseSerializer(CourseSerializer): ...@@ -352,10 +375,12 @@ class ProgramCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
def get_course_runs(self, course): def get_course_runs(self, course):
program = self.context['program'] course_runs = self.context['course_runs']
course_runs = program.course_runs.filter(course=course) course_runs = [course_run for course_run in course_runs if course_run.course == course]
if self.context.get('published_course_runs_only'): if self.context.get('published_course_runs_only'):
course_runs = course_runs.filter(status=CourseRunStatus.Published) course_runs = [course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published]
return CourseRunSerializer( return CourseRunSerializer(
course_runs, course_runs,
many=True, many=True,
...@@ -387,7 +412,7 @@ class ProgramSerializer(serializers.ModelSerializer): ...@@ -387,7 +412,7 @@ class ProgramSerializer(serializers.ModelSerializer):
staff = PersonSerializer(many=True) staff = PersonSerializer(many=True)
def get_courses(self, program): def get_courses(self, program):
courses = self.sort_courses(program) courses, course_runs = self.sort_courses(program)
course_serializer = ProgramCourseSerializer( course_serializer = ProgramCourseSerializer(
courses, courses,
...@@ -396,6 +421,7 @@ class ProgramSerializer(serializers.ModelSerializer): ...@@ -396,6 +421,7 @@ class ProgramSerializer(serializers.ModelSerializer):
'request': self.context.get('request'), 'request': self.context.get('request'),
'program': program, 'program': program,
'published_course_runs_only': self.context.get('published_course_runs_only'), 'published_course_runs_only': self.context.get('published_course_runs_only'),
'course_runs': course_runs,
} }
) )
...@@ -411,10 +437,14 @@ class ProgramSerializer(serializers.ModelSerializer): ...@@ -411,10 +437,14 @@ class ProgramSerializer(serializers.ModelSerializer):
course_runs should never be empty. If it is, key functions in this method attempting to find the course_runs should never be empty. If it is, key functions in this method attempting to find the
min of an empty sequence will raise a ValueError. min of an empty sequence will raise a ValueError.
""" """
course_runs = program.course_runs.select_related(*SELECT_RELATED_FIELDS['course_run'])
course_runs = course_runs.prefetch_related(*PREFETCH_FIELDS['course_run'])
course_runs = list(course_runs)
def min_run_enrollment_start(course): def min_run_enrollment_start(course):
# Enrollment starts may be empty. When this is the case, we make the same assumption as # Enrollment starts may be empty. When this is the case, we make the same assumption as
# the LMS: no enrollment_start is equivalent to (offset-aware) datetime.min. # the LMS: no enrollment_start is equivalent to (offset-aware) datetime.datetime.min.
min_datetime = datetime.min.replace(tzinfo=pytz.UTC) min_datetime = datetime.datetime.min.replace(tzinfo=pytz.UTC)
# Course runs excluded from the program are excluded here, too. # Course runs excluded from the program are excluded here, too.
# #
...@@ -423,12 +453,14 @@ class ProgramSerializer(serializers.ModelSerializer): ...@@ -423,12 +453,14 @@ class ProgramSerializer(serializers.ModelSerializer):
# values, while SQLite does the opposite. # values, while SQLite does the opposite.
# #
# For more, refer to https://docs.djangoproject.com/en/1.10/ref/models/querysets/#latest. # For more, refer to https://docs.djangoproject.com/en/1.10/ref/models/querysets/#latest.
run = min(program.course_runs.filter(course=course), key=lambda run: run.enrollment_start or min_datetime) _course_runs = [course_run for course_run in course_runs if course_run.course == course]
run = min(_course_runs, key=lambda run: run.enrollment_start or min_datetime)
return run.enrollment_start or min_datetime return run.enrollment_start or min_datetime
def min_run_start(course): def min_run_start(course):
run = min(program.course_runs.filter(course=course), key=lambda run: run.start) _course_runs = [course_run for course_run in course_runs if course_run.course == course]
run = min(_course_runs, key=lambda run: run.start)
return run.start return run.start
...@@ -436,7 +468,7 @@ class ProgramSerializer(serializers.ModelSerializer): ...@@ -436,7 +468,7 @@ class ProgramSerializer(serializers.ModelSerializer):
courses.sort(key=min_run_enrollment_start) courses.sort(key=min_run_enrollment_start)
courses.sort(key=min_run_start) courses.sort(key=min_run_start)
return courses return courses, course_runs
class Meta: class Meta:
model = Program model = Program
......
...@@ -191,7 +191,7 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -191,7 +191,7 @@ class ProgramCourseSerializerTests(TestCase):
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
self.course_list, self.course_list,
many=True, many=True,
context={'request': self.request, 'program': self.program} context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
) )
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
...@@ -204,7 +204,7 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -204,7 +204,7 @@ class ProgramCourseSerializerTests(TestCase):
serializer = ProgramCourseSerializer( serializer = ProgramCourseSerializer(
self.course_list, self.course_list,
many=True, many=True,
context={'request': self.request, 'program': self.program} context={'request': self.request, 'program': self.program, 'course_runs': self.program.course_runs}
) )
expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data expected = CourseSerializer(self.course_list, many=True, context={'request': self.request}).data
...@@ -221,7 +221,7 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -221,7 +221,7 @@ class ProgramCourseSerializerTests(TestCase):
excluded_runs.append(course_runs[0]) excluded_runs.append(course_runs[0])
program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs) program = ProgramFactory(courses=[course], excluded_course_runs=excluded_runs)
serializer_context = {'request': self.request, 'program': program} serializer_context = {'request': self.request, 'program': program, 'course_runs': program.course_runs}
serializer = ProgramCourseSerializer(course, context=serializer_context) serializer = ProgramCourseSerializer(course, context=serializer_context)
expected = CourseSerializer(course, context=serializer_context).data expected = CourseSerializer(course, context=serializer_context).data
...@@ -251,7 +251,8 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -251,7 +251,8 @@ class ProgramCourseSerializerTests(TestCase):
context={ context={
'request': self.request, 'request': self.request,
'program': self.program, 'program': self.program,
'published_course_runs_only': published_course_runs_only 'published_course_runs_only': published_course_runs_only,
'course_runs': self.program.course_runs
} }
) )
validate_data = serializer.data validate_data = serializer.data
...@@ -320,7 +321,7 @@ class ProgramSerializerTests(TestCase): ...@@ -320,7 +321,7 @@ class ProgramSerializerTests(TestCase):
'courses': ProgramCourseSerializer( 'courses': ProgramCourseSerializer(
program.courses, program.courses,
many=True, many=True,
context={'request': request, 'program': program} context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data, ).data,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data, 'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer( 'credit_backing_organizations': OrganizationSerializer(
...@@ -386,7 +387,7 @@ class ProgramSerializerTests(TestCase): ...@@ -386,7 +387,7 @@ class ProgramSerializerTests(TestCase):
'courses': ProgramCourseSerializer( 'courses': ProgramCourseSerializer(
program.courses, program.courses,
many=True, many=True,
context={'request': request, 'program': program} context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data, ).data,
'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data, 'corporate_endorsements': CorporateEndorsementSerializer(program.corporate_endorsements, many=True).data,
'credit_backing_organizations': OrganizationSerializer( 'credit_backing_organizations': OrganizationSerializer(
...@@ -447,7 +448,7 @@ class ProgramSerializerTests(TestCase): ...@@ -447,7 +448,7 @@ class ProgramSerializerTests(TestCase):
# The expected ordering is the reverse of course_list. # The expected ordering is the reverse of course_list.
course_list[::-1], course_list[::-1],
many=True, many=True,
context={'request': request, 'program': program} context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data ).data
self.assertEqual(serializer.data['courses'], expected) self.assertEqual(serializer.data['courses'], expected)
...@@ -496,7 +497,7 @@ class ProgramSerializerTests(TestCase): ...@@ -496,7 +497,7 @@ class ProgramSerializerTests(TestCase):
# The expected ordering is the reverse of course_list. # The expected ordering is the reverse of course_list.
course_list[::-1], course_list[::-1],
many=True, many=True,
context={'request': request, 'program': program} context={'request': request, 'program': program, 'course_runs': program.course_runs}
).data ).data
self.assertEqual(serializer.data['courses'], expected) self.assertEqual(serializer.data['courses'], expected)
......
...@@ -34,24 +34,6 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part ...@@ -34,24 +34,6 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
PREFETCH_FIELDS = {
'course_run': [
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff',
'staff__position', 'staff__position__organization', 'language',
],
'course': [
'level_type', 'video', 'programs', 'course_runs', 'subjects', 'prerequisites', 'expected_learning_items',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'sponsoring_organizations', 'sponsoring_organizations__tags', 'sponsoring_organizations__partner',
],
}
SELECT_RELATED_FIELDS = {
'course': ['level_type', 'video', ],
'course_run': ['course', 'language', 'video', ],
}
def prefetch_related_objects_for_courses(queryset): def prefetch_related_objects_for_courses(queryset):
""" """
...@@ -69,13 +51,33 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -69,13 +51,33 @@ def prefetch_related_objects_for_courses(queryset):
Returns: Returns:
QuerySet QuerySet
""" """
_prefetch_fields = serializers.PREFETCH_FIELDS
_select_related_fields = serializers.SELECT_RELATED_FIELDS
# Prefetch the data for the related course runs # Prefetch the data for the related course runs
course_run_prefetch_fields = PREFETCH_FIELDS['course_run'] + SELECT_RELATED_FIELDS['course_run'] course_run_prefetch_fields = _prefetch_fields['course_run'] + _select_related_fields['course_run']
course_run_prefetch_fields = ['course_runs__' + field for field in course_run_prefetch_fields] course_run_prefetch_fields = ['course_runs__' + field for field in course_run_prefetch_fields]
queryset = queryset.prefetch_related(*course_run_prefetch_fields) queryset = queryset.prefetch_related(*course_run_prefetch_fields)
queryset = queryset.select_related(*SELECT_RELATED_FIELDS['course']) queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*PREFETCH_FIELDS['course']) queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
def prefetch_related_objects_for_programs(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Program`.
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
course = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
course = ['courses__' + field for field in course]
queryset = queryset.prefetch_related(*course)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['program'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['program'])
return queryset return queryset
...@@ -199,9 +201,9 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -199,9 +201,9 @@ class CatalogViewSet(viewsets.ModelViewSet):
course_runs = CourseRun.objects.filter(course__in=courses).active().marketable() course_runs = CourseRun.objects.filter(course__in=courses).active().marketable()
# We use select_related and prefetch_related to decrease our database query count # We use select_related and prefetch_related to decrease our database query count
course_runs = course_runs.select_related(*SELECT_RELATED_FIELDS['course_run']) course_runs = course_runs.select_related(*serializers.SELECT_RELATED_FIELDS['course_run'])
prefetch_fields = ['course__' + field for field in PREFETCH_FIELDS['course']] prefetch_fields = ['course__' + field for field in serializers.PREFETCH_FIELDS['course']]
prefetch_fields += PREFETCH_FIELDS['course_run'] prefetch_fields += serializers.PREFETCH_FIELDS['course_run']
course_runs = course_runs.prefetch_related(*prefetch_fields) course_runs = course_runs.prefetch_related(*prefetch_fields)
serializer = serializers.FlattenedCourseRunWithCourseSerializer( serializer = serializers.FlattenedCourseRunWithCourseSerializer(
...@@ -295,8 +297,8 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -295,8 +297,8 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
return qs return qs
else: else:
queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner) queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner)
queryset = queryset.select_related(*SELECT_RELATED_FIELDS['course_run']) queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['course_run'])
queryset = queryset.prefetch_related(*PREFETCH_FIELDS['course_run']) queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['course_run'])
return queryset return queryset
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
...@@ -391,7 +393,7 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -391,7 +393,7 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """ """ Program resource. """
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
queryset = Program.objects.all() queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
......
...@@ -5,6 +5,7 @@ from urllib.parse import urljoin ...@@ -5,6 +5,7 @@ from urllib.parse import urljoin
from uuid import uuid4 from uuid import uuid4
import pytz import pytz
import waffle
from django.db import models, transaction from django.db import models, transaction
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
...@@ -15,7 +16,6 @@ from haystack.query import SearchQuerySet ...@@ -15,7 +16,6 @@ from haystack.query import SearchQuerySet
from sortedm2m.fields import SortedManyToManyField from sortedm2m.fields import SortedManyToManyField
from stdimage.models import StdImageField from stdimage.models import StdImageField
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
import waffle
from course_discovery.apps.core.models import Currency, Partner from course_discovery.apps.core.models import Currency, Partner
from course_discovery.apps.course_metadata.choices import CourseRunStatus, CourseRunPacing, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, CourseRunPacing, ProgramStatus
...@@ -639,24 +639,28 @@ class Program(TimeStampedModel): ...@@ -639,24 +639,28 @@ class Program(TimeStampedModel):
@property @property
def languages(self): def languages(self):
return set(course_run.language for course_run in self.course_runs if course_run.language is not None) course_runs = self.course_runs.select_related('language')
return set(course_run.language for course_run in course_runs if course_run.language is not None)
@property @property
def transcript_languages(self): def transcript_languages(self):
languages = [list(course_run.transcript_languages.all()) for course_run in self.course_runs] course_runs = self.course_runs.prefetch_related('transcript_languages')
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = itertools.chain.from_iterable(languages) languages = itertools.chain.from_iterable(languages)
return set(languages) return set(languages)
@property @property
def subjects(self): def subjects(self):
subjects = [list(course.subjects.all()) for course in self.courses.all()] courses = self.courses.prefetch_related('subjects')
subjects = [list(course.subjects.all()) for course in courses]
subjects = itertools.chain.from_iterable(subjects) subjects = itertools.chain.from_iterable(subjects)
return set(subjects) return set(subjects)
@property @property
def seats(self): def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True) applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True)
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency')
@property @property
def seat_types(self): def seat_types(self):
...@@ -699,8 +703,7 @@ class Program(TimeStampedModel): ...@@ -699,8 +703,7 @@ class Program(TimeStampedModel):
return self.status == ProgramStatus.Active return self.status == ProgramStatus.Active
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if waffle.switch_is_active('publish_program_to_marketing_site') and \ if waffle.switch_is_active('publish_program_to_marketing_site') and self.partner.has_marketing_site:
self.partner.has_marketing_site:
# Before save, get from database the existing data if exists # Before save, get from database the existing data if exists
existing_program = None existing_program = None
if self.id: if self.id:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment