Commit bc3d7d01 by Renzo Lucioni

Program query optimizations

parent 3dd051a1
from django.test import TestCase from django.test import TestCase
from course_discovery.apps.api.fields import ImageField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.tests.test_serializers import make_request
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase): class ImageFieldTests(TestCase):
...@@ -13,3 +16,27 @@ class ImageFieldTests(TestCase): ...@@ -13,3 +16,27 @@ class ImageFieldTests(TestCase):
'width': None 'width': None
} }
self.assertEqual(ImageField().to_representation(value), expected) self.assertEqual(ImageField().to_representation(value), expected)
# pylint: disable=no-member
class StdImageSerializerFieldTests(TestCase):
    """ Tests for StdImageSerializerField. """

    def test_to_representation(self):
        """ Verify each image variation is serialized with its testserver-prefixed URL, width, and height. """
        request = make_request()
        # TODO Create test-only model to avoid unnecessary dependency on Program model.
        program = ProgramFactory(banner_image=make_image_file('test.jpg'))
        image = program.banner_image
        variations = image.field.variations

        field = StdImageSerializerField()
        field._context = {'request': request}  # pylint: disable=protected-access

        # Build one entry per configured variation, keyed by the variation's name.
        expected = {}
        for size_key in variations:
            dimensions = variations[size_key]
            expected[size_key] = {
                'url': '{}{}'.format('http://testserver', getattr(image, size_key).url),
                'width': dimensions['width'],
                'height': dimensions['height'],
            }

        self.assertDictEqual(field.to_representation(image), expected)
...@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2') CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2')
CourseRunFactory(enrollment_end=enrollment_end, course=self.course) CourseRunFactory(enrollment_end=enrollment_end, course=self.course)
with self.assertNumQueries(41): with self.assertNumQueries(40):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True))
......
...@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(19): with self.assertNumQueries(18):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(25): with self.assertNumQueries(24):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(38): with self.assertNumQueries(35):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase): ...@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = ProgramFactory()
self.assert_retrieve_success(program) with self.assertNumQueries(33):
self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory() course = CourseFactory()
program = ProgramFactory(courses=[course]) program = ProgramFactory(courses=[course])
self.assert_retrieve_success(program) with self.assertNumQueries(55):
self.assert_retrieve_success(program)
def assert_list_results(self, url, expected): def assert_list_results(self, url, expected, expected_query_count):
""" """
Asserts the results serialized/returned at the URL matches those that are expected. Asserts the results serialized/returned at the URL matches those that are expected.
Args: Args:
...@@ -62,7 +64,9 @@ class ProgramViewSetTests(APITestCase): ...@@ -62,7 +64,9 @@ class ProgramViewSetTests(APITestCase):
Returns: Returns:
None None
""" """
response = self.client.get(url) with self.assertNumQueries(expected_query_count):
response = self.client.get(url)
self.assertEqual( self.assertEqual(
response.data['results'], response.data['results'],
ProgramSerializer(expected, many=True, context={'request': self.request}).data ProgramSerializer(expected, many=True, context={'request': self.request}).data
...@@ -72,17 +76,17 @@ class ProgramViewSetTests(APITestCase): ...@@ -72,17 +76,17 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = ProgramFactory.create_batch(3)
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected) self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name) program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program]) self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar' url = self.list_path + '?type=bar'
self.assert_list_results(url, []) self.assert_list_results(url, [], 4)
def test_filter_by_uuids(self): def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """ """ Verify that the endpoint filters programs to those matching the provided UUIDs. """
...@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
self.assert_list_results(url, expected) self.assert_list_results(url, expected, 14)
@ddt.data( @ddt.data(
(ProgramStatus.Unpublished, False), (ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True), (ProgramStatus.Active, True, 14),
) )
@ddt.unpack @ddt.unpack
def test_filter_by_marketable(self, status, is_marketable): def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """ """ Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1' url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='') ProgramFactory(marketing_slug='')
...@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase): ...@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase):
expected = programs if is_marketable else [] expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected) self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected) self.assert_list_results(url, expected, expected_query_count)
...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset return queryset
def prefetch_related_objects_for_programs(queryset):
    """
    Adds the related-object lookups that are needed when a `Program` is serialized.

    Course-level lookups are applied through the `courses` relation, followed by
    the program-level select_related and prefetch_related fields.

    Args:
        queryset (QuerySet): original query

    Returns:
        QuerySet
    """
    course_fields = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
    # Course fields hang off the program's `courses` relation, so prefix each lookup.
    course_lookups = ['courses__' + name for name in course_fields]

    return (
        queryset
        .prefetch_related(*course_lookups)
        .select_related(*serializers.SELECT_RELATED_FIELDS['program'])
        .prefetch_related(*serializers.PREFETCH_FIELDS['program'])
    )
# pylint: disable=no-member # pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet): class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """ """ Catalog resource. """
...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """ """ Program resource. """
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter filter_class = filters.ProgramFilter
def get_queryset(self):
    """
    Return the queryset produced by the serializer class's `prefetch_queryset()`.

    This method prevents prefetches on the program queryset from "stacking,"
    which happens when the queryset is stored in a class property.
    """
    serializer_class = self.serializer_class
    return serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0)) context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime import datetime
import itertools import itertools
import logging import logging
from collections import defaultdict
from urllib.parse import urljoin from urllib.parse import urljoin
from uuid import uuid4 from uuid import uuid4
...@@ -635,57 +636,66 @@ class Program(TimeStampedModel): ...@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property @property
def course_runs(self): def course_runs(self):
"""
Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
Otherwise, this method will incur many, many queries when fetching related courses and course runs.
"""
excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()] excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()]
return CourseRun.objects.filter(course__programs=self).exclude(id__in=excluded_course_run_ids)
for course in self.courses.all():
for run in course.course_runs.all():
if run.id not in excluded_course_run_ids:
yield run
@property @property
def languages(self): def languages(self):
course_runs = self.course_runs.select_related('language') return set(course_run.language for course_run in self.course_runs if course_run.language is not None)
return set(course_run.language for course_run in course_runs if course_run.language is not None)
@property @property
def transcript_languages(self): def transcript_languages(self):
course_runs = self.course_runs.prefetch_related('transcript_languages') languages = [course_run.transcript_languages.all() for course_run in self.course_runs]
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = itertools.chain.from_iterable(languages) languages = itertools.chain.from_iterable(languages)
return set(languages) return set(languages)
@property @property
def subjects(self): def subjects(self):
courses = self.courses.prefetch_related('subjects') subjects = [course.subjects.all() for course in self.courses.all()]
subjects = [list(course.subjects.all()) for course in courses]
subjects = itertools.chain.from_iterable(subjects) subjects = itertools.chain.from_iterable(subjects)
return set(subjects) return set(subjects)
@property @property
def seats(self): def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True) applicable_seat_types = set(seat_type.slug for seat_type in self.type.applicable_seat_types.all())
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency') for run in self.course_runs:
for seat in run.seats.all():
if seat.type in applicable_seat_types:
yield seat
@property @property
def seat_types(self): def seat_types(self):
return set(self.seats.values_list('type', flat=True)) return set(seat.type for seat in self.seats)
@property @property
def price_ranges(self): def price_ranges(self):
seats = self.seats.values('currency').annotate(models.Min('price'), models.Max('price')) currencies = defaultdict(list)
price_ranges = [] for seat in self.seats:
currencies[seat.currency].append(seat.price)
for seat in seats: price_ranges = []
for currency, prices in currencies.items():
price_ranges.append({ price_ranges.append({
'currency': seat['currency'], 'currency': currency.code,
'min': seat['price__min'], 'min': min(prices),
'max': seat['price__max'], 'max': max(prices),
}) })
return price_ranges return price_ranges
@property @property
def start(self): def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """ """ Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs if self.course_runs:
if course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start] start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates: if start_dates:
...@@ -695,7 +705,7 @@ class Program(TimeStampedModel): ...@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property @property
def staff(self): def staff(self):
staff = [list(course_run.staff.all()) for course_run in self.course_runs] staff = [course_run.staff.all() for course_run in self.course_runs]
staff = itertools.chain.from_iterable(staff) staff = itertools.chain.from_iterable(staff)
return set(staff) return set(staff)
......
...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin): ...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()] return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj): def prepare_staff_uuids(self, obj):
return [str(staff.uuid) for course_run in obj.course_runs.all() for staff in course_run.staff.all()] return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj): def prepare_credit_backing_organizations(self, obj):
return self._prepare_organizations(obj.credit_backing_organizations.all()) return self._prepare_organizations(obj.credit_backing_organizations.all())
......
...@@ -100,7 +100,7 @@ class AdminTests(TestCase): ...@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """ """ Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count()) self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all())) self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = { params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id], 'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
...@@ -114,7 +114,7 @@ class AdminTests(TestCase): ...@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(2, self.program.excluded_course_runs.all().count()) self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all())) self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self): def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """ """ Verify that course selection page without posting any selected excluded check run. """
...@@ -132,7 +132,7 @@ class AdminTests(TestCase): ...@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(0, self.program.excluded_course_runs.all().count()) self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all())) self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,))) response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")') self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment