Commit 2e556a19 by Renzo Lucioni

Prevent course API from discarding prefetched data

filter() and exclude() calls made by the serializer after prefetching were causing prefetched data to be discarded. Moving that filtering ahead of the prefetch prevents prefetched data from being discarded and saves many duplicate queries.
parent 3e42af4f
......@@ -449,6 +449,12 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer):
"""A ``CourseRunSerializer`` which includes programs derived from parent course."""
programs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None):
queryset = super().prefetch_queryset(queryset=queryset)
return queryset.prefetch_related('course__programs')
def get_programs(self, obj):
programs = []
# Filter out non-deleted programs which this course_run is part of the program course_run exclusion
......@@ -482,14 +488,14 @@ class MinimalCourseSerializer(TimestampModelSerializer):
image = ImageField(read_only=True, source='card_image_url')
@classmethod
def prefetch_queryset(cls, queryset=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all()
return queryset.select_related('partner').prefetch_related(
'authoring_organizations',
Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset()),
Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset(queryset=course_runs)),
)
class Meta:
......@@ -510,7 +516,7 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all()
......@@ -519,7 +525,7 @@ class CourseSerializer(MinimalCourseSerializer):
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset()),
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
......@@ -544,36 +550,26 @@ class CourseWithProgramsSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField()
programs = serializers.SerializerMethodField()
def get_course_runs(self, course):
# exclude() isn't used to avoid discarding prefetched results.
course_runs = [course_run for course_run in course.course_runs.all() if not course_run.hidden]
if self.context.get('marketable_course_runs_only'):
# A client requesting marketable_course_runs_only should only receive course runs
# that are published, have seats, and can still be enrolled in. All other course runs
# should be excluded. As an unfortunate side-effect of the way we've marketed course
# runs in the past - a course run could be marketed despite enrollment in that run being
# closed - achieving this requires applying both the marketable and active filters.
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().active()
if self.context.get('marketable_enrollable_course_runs_with_archived'):
# Same as "marketable_course_runs_only", but includes courses with an end date in the past
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().enrollable()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
if self.context.get('published_course_runs_only'):
# filter() isn't used to avoid discarding prefetched results.
course_runs = [
course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published
]
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
def get_course_runs(self, course):
return CourseRunSerializer(
course_runs,
course.course_runs,
many=True,
context={
'request': self.context.get('request'),
......@@ -603,33 +599,25 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
available_course_runs = CourseRun.objects.active().enrollable().marketable()
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch(
'course_runs',
queryset=CourseRunSerializer.prefetch_queryset(queryset=available_course_runs),
# Using to_attr is recommended when filtering down a prefetch
# result as it is less ambiguous than storing a filtered result
# in the related manager’s cache and accessing it via all().
to_attr='available_course_runs'
),
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
def get_course_runs(self, course):
return CourseRunSerializer(
course.available_course_runs,
course.course_runs,
many=True,
context=self.context
).data
......
......@@ -201,99 +201,6 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
)
self.assertEqual(serializer.data, self.get_expected_data(self.course, self.request))
@ddt.data(0, 1)
def test_marketable_course_runs_only(self, marketable_course_runs_only):
"""
Verify that the marketable_course_runs_only option is respected, restricting returned
course runs to those that are published, have seats, and can still be enrolled in.
"""
enrollable_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
SeatFactory(course_run=enrollable_course_run)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
SeatFactory(course_run=unpublished_course_run)
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
closed_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=10),
course=self.course
)
SeatFactory(course_run=closed_course_run)
serializer = self.serializer_class(
self.course,
context={'request': self.request, 'marketable_course_runs_only': marketable_course_runs_only}
)
self.assertEqual(
len(serializer.data['course_runs']),
1 if marketable_course_runs_only else 4
)
def test_marketable_enrollable_course_runs_with_archived(self):
"""
Verify that the marketable_enrollable_course_runs_with_archived option is respected, restricting returned
course runs to those that are published, have seats, and can still be enrolled in
(including courses with an end date in the past.)
"""
enrollable_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
archived_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
SeatFactory(course_run=unpublished_course_run)
SeatFactory(course_run=enrollable_course_run)
SeatFactory(course_run=archived_course_run)
context = {
'request': self.request,
'marketable_enrollable_course_runs_with_archived': 1
}
course_serializer = self.serializer_class(
self.course,
context=context
)
course_run_keys = [course_run['key'] for course_run in course_serializer.data['course_runs']]
# order doesn't matter
assert sorted(course_run_keys) == sorted([enrollable_course_run.key, archived_course_run.key])
@ddt.data(0, 1)
def test_published_course_runs_only(self, published_course_runs_only):
"""
Test that the published_course_runs_only flag hides unpublished course runs
"""
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished)
published_course_run = CourseRunFactory(status=CourseRunStatus.Published)
self.course.course_runs.add(unpublished_course_run, published_course_run)
serializer = self.serializer_class(
self.course,
context={'request': self.request, 'published_course_runs_only': published_course_runs_only}
)
self.assertEqual(len(serializer.data['course_runs']), 2 - published_course_runs_only)
class MinimalCourseRunSerializerTests(TestCase):
serializer_class = MinimalCourseRunSerializer
......
......@@ -170,15 +170,16 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# This run has no seats, but we still expect its parent course
# to be included.
CourseRunFactory(course=course)
filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(18):
response = self.client.get(url)
# Prefetched results are assigned to a custom attribute.
course.available_course_runs = [course_run]
assert response.status_code == 200
# Emulate prefetching behavior.
filtered_course_run.delete()
assert response.data['results'] == self.serialize_catalog_course([course], many=True)
# Any course appearing in the response must have at least one serialized run.
......
......@@ -51,7 +51,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key})
with self.assertNumQueries(13):
with self.assertNumQueries(10):
response = self.client.get(url)
assert response.status_code == 200
assert response.data.get('programs') == []
......@@ -78,7 +78,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key})
with self.assertNumQueries(13):
with self.assertNumQueries(11):
response = self.client.get(url)
assert response.status_code == 200
assert response.data.get('programs') == []
......@@ -130,10 +130,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
assert response.status_code == 403
def test_list(self):
""" Verify the endpoint returns a list of all catalogs. """
""" Verify the endpoint returns a list of all course runs. """
url = reverse('api:v1:course_run-list')
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.get(url)
assert response.status_code == 200
......@@ -143,10 +143,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
)
def test_list_sorted_by_course_start_date(self):
""" Verify the endpoint returns a list of all catalogs sorted by course start date. """
""" Verify the endpoint returns a list of all course runs sorted by start date. """
url = '{root}?ordering=start'.format(root=reverse('api:v1:course_run-list'))
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.get(url)
assert response.status_code == 200
......
......@@ -80,7 +80,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
SeatFactory(course_run=unpublished_course_run)
# Published course run with no seats.
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
no_seats_course_run = CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
# Published course run with a seat and an end date in the past.
closed_course_run = CourseRunFactory(
......@@ -94,14 +94,14 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
url = '{}?marketable_course_runs_only={}'.format(url, marketable_course_runs_only)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.data,
self.serialize_course(
self.course,
extra_context={'marketable_course_runs_only': marketable_course_runs_only}
)
)
assert response.status_code == 200
if marketable_course_runs_only:
# Emulate prefetching behavior.
for course_run in (unpublished_course_run, no_seats_course_run, closed_course_run):
course_run.delete()
assert response.data == self.serialize_course(self.course)
@ddt.data(1, 0)
def test_marketable_enrollable_course_runs_with_archived(self, marketable_enrollable_course_runs_with_archived):
......@@ -111,13 +111,17 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
past = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2)
future = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
CourseRunFactory(enrollment_start=None, enrollment_end=future, course=self.course)
CourseRunFactory(enrollment_start=None, enrollment_end=None, course=self.course)
CourseRunFactory(
enrollment_start=past, enrollment_end=future, course=self.course
)
CourseRunFactory(enrollment_start=future, course=self.course)
CourseRunFactory(enrollment_end=past, course=self.course)
course_run = CourseRunFactory(enrollment_start=None, enrollment_end=future, course=self.course)
SeatFactory(course_run=course_run)
filtered_course_runs = [
CourseRunFactory(enrollment_start=None, enrollment_end=None, course=self.course),
CourseRunFactory(
enrollment_start=past, enrollment_end=future, course=self.course
),
CourseRunFactory(enrollment_start=future, course=self.course),
CourseRunFactory(enrollment_end=past, course=self.course),
]
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url = '{}?marketable_enrollable_course_runs_with_archived={}'.format(
......@@ -126,12 +130,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
response = self.client.get(url)
assert response.status_code == 200
assert response.data == self.serialize_course(
self.course,
extra_context={
'marketable_enrollable_course_runs_with_archived': marketable_enrollable_course_runs_with_archived
}
)
if marketable_enrollable_course_runs_with_archived:
# Emulate prefetching behavior.
for course_run in filtered_course_runs:
course_run.delete()
assert response.data == self.serialize_course(self.course)
@ddt.data(1, 0)
def test_get_include_published_course_run(self, published_course_runs_only):
......@@ -140,15 +145,20 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
the 'published_course_runs_only' flag is set to True
"""
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url = '{}?published_course_runs_only={}'.format(url, published_course_runs_only)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.data,
self.serialize_course(self.course, extra_context={'published_course_runs_only': published_course_runs_only})
)
assert response.status_code == 200
if published_course_runs_only:
# Emulate prefetching behavior.
unpublished_course_run.delete()
assert response.data == self.serialize_course(self.course)
def test_list(self):
""" Verify the endpoint returns a list of all courses. """
......@@ -170,7 +180,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(58):
with self.assertNumQueries(37):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
......@@ -91,7 +91,12 @@ class CatalogViewSet(viewsets.ModelViewSet):
"""
catalog = self.get_object()
queryset = catalog.courses().available()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset(queryset=queryset)
course_runs = CourseRun.objects.active().enrollable().marketable()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset(
queryset=queryset,
course_runs=course_runs
)
page = self.paginate_queryset(queryset)
serializer = serializers.CatalogCourseSerializer(page, many=True, context={'request': request})
......
......@@ -40,7 +40,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
paramType: query
multiple: false
"""
q = self.request.query_params.get('q', None)
q = self.request.query_params.get('q')
partner = self.get_partner()
if q:
......@@ -50,9 +50,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
return qs
else:
queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['course_run'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['course_run'])
return queryset
return self.get_serializer_class().prefetch_queryset(queryset=queryset)
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
......
......@@ -6,8 +6,9 @@ from rest_framework.permissions import IsAuthenticated
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX
from course_discovery.apps.course_metadata.models import Course
from course_discovery.apps.course_metadata.models import Course, CourseRun
# pylint: disable=no-member
......@@ -26,19 +27,34 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination
def get_queryset(self):
q = self.request.query_params.get('q', None)
q = self.request.query_params.get('q')
if q:
queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset)
else:
queryset = self.get_serializer_class().prefetch_queryset()
course_runs = CourseRun.objects.exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active()
if get_query_param(self.request, 'marketable_enrollable_course_runs_with_archived'):
course_runs = course_runs.marketable().enrollable()
if get_query_param(self.request, 'published_course_runs_only'):
course_runs = course_runs.filter(status=CourseRunStatus.Published)
queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset,
course_runs=course_runs
)
return queryset.order_by(Lower('key'))
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
query_params = ['exclude_utm', 'include_deleted_programs', 'marketable_course_runs_only',
'marketable_enrollable_course_runs_with_archived', 'published_course_runs_only']
query_params = ['exclude_utm', 'include_deleted_programs']
for query_param in query_params:
context[query_param] = get_query_param(self.request, query_param)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment