Commit a4495712 by Renzo Lucioni

Prevent serializer from discarding prefetched data

Chained calls to queryset methods like filter() and exclude() imply a new database query and will cause prefetched results to be ignored. This change is an incremental improvement which more than halves the query count for requests which don't use the marketable_course_runs_only or marketable_enrollable_course_runs_with_archived querystring parameters (from 51 to 22 queries on a system with 1 course and 3 nested course runs). Reducing the query count for requests which do use those parameters requires filtering in Python which mimics the active, marketable, and enrollable queryset methods.
parent f9e64f3c
...@@ -258,10 +258,11 @@ class PersonSerializer(serializers.ModelSerializer): ...@@ -258,10 +258,11 @@ class PersonSerializer(serializers.ModelSerializer):
return person return person
def get_social_network_url(self, url_type, obj): def get_social_network_url(self, url_type, obj):
social_network = obj.person_networks.filter(type=url_type).first() # filter() isn't used to avoid discarding prefetched results.
social_networks = [network for network in obj.person_networks.all() if network.type == url_type]
if social_network: if social_networks:
return social_network.value return social_networks[0].value
def get_urls(self, obj): def get_urls(self, obj):
return { return {
...@@ -531,7 +532,8 @@ class CourseWithProgramsSerializer(CourseSerializer): ...@@ -531,7 +532,8 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs = serializers.SerializerMethodField() programs = serializers.SerializerMethodField()
def get_course_runs(self, course): def get_course_runs(self, course):
course_runs = course.course_runs.exclude(hidden=True) # exclude() isn't used to avoid discarding prefetched results.
course_runs = [course_run for course_run in course.course_runs.all() if not course_run.hidden]
if self.context.get('marketable_course_runs_only'): if self.context.get('marketable_course_runs_only'):
# A client requesting marketable_course_runs_only should only receive course runs # A client requesting marketable_course_runs_only should only receive course runs
...@@ -539,14 +541,23 @@ class CourseWithProgramsSerializer(CourseSerializer): ...@@ -539,14 +541,23 @@ class CourseWithProgramsSerializer(CourseSerializer):
# should be excluded. As an unfortunate side-effect of the way we've marketed course # should be excluded. As an unfortunate side-effect of the way we've marketed course
# runs in the past - a course run could be marketed despite enrollment in that run being # runs in the past - a course run could be marketed despite enrollment in that run being
# closed - achieving this requires applying both the marketable and active filters. # closed - achieving this requires applying both the marketable and active filters.
course_runs = course_runs.marketable().active()
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().active()
if self.context.get('marketable_enrollable_course_runs_with_archived'): if self.context.get('marketable_enrollable_course_runs_with_archived'):
# Same as "marketable_course_runs_only", but includes courses with an end date in the past # Same as "marketable_course_runs_only", but includes courses with an end date in the past
course_runs = course_runs.marketable().enrollable()
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().enrollable()
if self.context.get('published_course_runs_only'): if self.context.get('published_course_runs_only'):
course_runs = course_runs.filter(status=CourseRunStatus.Published) # filter() isn't used to avoid discarding prefetched results.
course_runs = [
course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published
]
return CourseRunSerializer( return CourseRunSerializer(
course_runs, course_runs,
......
...@@ -29,7 +29,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -29,7 +29,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(20): with self.assertNumQueries(18):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """ """ Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(13): with self.assertNumQueries(11):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), []) self.assertEqual(response.data.get('programs'), [])
...@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1' url += '?include_deleted_programs=1'
with self.assertNumQueries(23): with self.assertNumQueries(22):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual( self.assertEqual(
...@@ -154,7 +154,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -154,7 +154,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(26): with self.assertNumQueries(24):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -181,7 +181,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -181,7 +181,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(41): with self.assertNumQueries(37):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
...@@ -64,7 +64,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -64,7 +64,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = self.create_program() program = self.create_program()
with self.assertNumQueries(42): with self.assertNumQueries(39):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
......
...@@ -5,7 +5,7 @@ from rest_framework.permissions import IsAuthenticated ...@@ -5,7 +5,7 @@ from rest_framework.permissions import IsAuthenticated
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.v1.views import get_query_param, prefetch_related_objects_for_courses from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course
...@@ -31,8 +31,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -31,8 +31,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
if q: if q:
queryset = Course.search(q) queryset = Course.search(q)
else: else:
queryset = super(CourseViewSet, self).get_queryset() queryset = self.get_serializer_class().prefetch_queryset()
queryset = prefetch_related_objects_for_courses(queryset)
return queryset.order_by(Lower('key')) return queryset.order_by(Lower('key'))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment