Commit 5646091d by Clinton Blackburn

Pre-fetching data for course API responses

ECOM-5559 and ECOM-5440
parent bf1a3bc7
...@@ -141,9 +141,10 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -141,9 +141,10 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2') CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2')
CourseRunFactory(enrollment_end=enrollment_end, course=self.course) CourseRunFactory(enrollment_end=enrollment_end, course=self.course)
response = self.client.get(url) with self.assertNumQueries(40):
self.assertEqual(response.status_code, 200) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True)) self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True))
def test_contains(self): def test_contains(self):
""" Verify the endpoint returns a filtered list of courses contained in the catalog. """ """ Verify the endpoint returns a filtered list of courses contained in the catalog. """
...@@ -165,7 +166,9 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -165,7 +166,9 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
SeatFactory(type='credit', course_run=self.course_run, credit_provider='Hogwarts', credit_hours=4) SeatFactory(type='credit', course_run=self.course_run, credit_provider='Hogwarts', credit_hours=4)
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
response = self.client.get(url)
with self.assertNumQueries(24):
response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run) course_run = self.serialize_catalog_flat_course_run(self.course_run)
expected = ','.join([ expected = ','.join([
......
...@@ -19,20 +19,22 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -19,20 +19,22 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
response = self.client.get(url) with self.assertNumQueries(19):
self.assertEqual(response.status_code, 200) response = self.client.get(url)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
response = self.client.get(url) with self.assertNumQueries(25):
self.assertEqual(response.status_code, 200) response = self.client.get(url)
self.assertListEqual( self.assertEqual(response.status_code, 200)
response.data['results'], self.assertListEqual(
self.serialize_course(Course.objects.all().order_by(Lower('key')), many=True) response.data['results'],
) self.serialize_course(Course.objects.all().order_by(Lower('key')), many=True)
)
def test_list_query(self): def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """ """ Verify the endpoint returns a filtered list of courses """
...@@ -42,8 +44,9 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -42,8 +44,9 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
response = self.client.get(url) with self.assertNumQueries(62):
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self): def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """ """ Verify the endpoint returns a list of courses filtered by the specified keys. """
...@@ -52,5 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -52,5 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
response = self.client.get(url) with self.assertNumQueries(38):
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -36,6 +36,45 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part ...@@ -36,6 +36,45 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
PREFETCH_FIELDS = {
'course_run': (
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff',
'staff__position', 'staff__position__organization', 'language',
),
'course': (
'level_type', 'video', 'programs', 'course_runs', 'subjects', 'prerequisites', 'expected_learning_items',
'authoring_organizations', 'authoring_organizations__tags', 'authoring_organizations__partner',
'sponsoring_organizations', 'sponsoring_organizations__tags', 'sponsoring_organizations__partner',
),
}
def prefetch_related_objects_for_courses(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Course`.
Pre-fetching allows us to consolidate our database queries rather than run
thousands of queries as we serialize the data. For details, see the links below:
- https://docs.djangoproject.com/en/1.10/ref/models/querysets/#select-related
- https://docs.djangoproject.com/en/1.10/ref/models/querysets/#prefetch-related
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
# Prefetch the data for the related course runs
course_run_prefetch_fields = PREFETCH_FIELDS['course_run']
course_run_prefetch_fields = ['course_runs__' + field for field in course_run_prefetch_fields]
queryset = queryset.prefetch_related(*course_run_prefetch_fields)
queryset = queryset.select_related('level_type', 'video')
queryset = queryset.prefetch_related(*PREFETCH_FIELDS['course'])
return queryset
# pylint: disable=no-member # pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet): class CatalogViewSet(viewsets.ModelViewSet):
...@@ -110,6 +149,7 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -110,6 +149,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
catalog = self.get_object() catalog = self.get_object()
queryset = catalog.courses().active() queryset = catalog.courses().active()
queryset = prefetch_related_objects_for_courses(queryset)
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
serializer = serializers.CourseSerializerExcludingClosedRuns(page, many=True, context={'request': request}) serializer = serializers.CourseSerializerExcludingClosedRuns(page, many=True, context={'request': request})
...@@ -187,6 +227,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -187,6 +227,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Course.search(q) queryset = Course.search(q)
else: else:
queryset = super(CourseViewSet, self).get_queryset() queryset = super(CourseViewSet, self).get_queryset()
queryset = prefetch_related_objects_for_courses(queryset)
return queryset.order_by(Lower('key')) return queryset.order_by(Lower('key'))
...@@ -249,11 +290,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -249,11 +290,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
else: else:
queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner) queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner)
queryset = queryset.select_related('course', 'language', 'video') queryset = queryset.select_related('course', 'language', 'video')
queryset = queryset.prefetch_related( queryset = queryset.prefetch_related(*PREFETCH_FIELDS['course_run'])
'course__partner', 'course__level_type', 'course__programs', 'course__programs__type',
'course__programs__partner', 'seats', 'transcript_languages', 'seats__currency', 'staff',
'staff__position', 'staff__position__organization'
)
return queryset return queryset
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment