Commit bc3d7d01 by Renzo Lucioni

Program query optimizations

parent 3dd051a1
from django.test import TestCase
from course_discovery.apps.api.fields import ImageField
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.tests.test_serializers import make_request
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ImageFieldTests(TestCase):
......@@ -13,3 +16,27 @@ class ImageFieldTests(TestCase):
'width': None
}
self.assertEqual(ImageField().to_representation(value), expected)
# pylint: disable=no-member
class StdImageSerializerFieldTests(TestCase):
def test_to_representation(self):
request = make_request()
# TODO Create test-only model to avoid unnecessary dependency on Program model.
program = ProgramFactory(banner_image=make_image_file('test.jpg'))
field = StdImageSerializerField()
field._context = {'request': request} # pylint: disable=protected-access
expected = {
size_key: {
'url': '{}{}'.format(
'http://testserver',
getattr(program.banner_image, size_key).url
),
'width': program.banner_image.field.variations[size_key]['width'],
'height': program.banner_image.field.variations[size_key]['height']
}
for size_key in program.banner_image.field.variations
}
self.assertDictEqual(field.to_representation(program.banner_image), expected)
......@@ -141,7 +141,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
CourseRunFactory(enrollment_end=enrollment_end, course__title='ABC Test Course 2')
CourseRunFactory(enrollment_end=enrollment_end, course=self.course)
with self.assertNumQueries(41):
with self.assertNumQueries(40):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_catalog_course(courses, many=True))
......
......@@ -19,7 +19,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(19):
with self.assertNumQueries(18):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -28,7 +28,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(25):
with self.assertNumQueries(24):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -55,6 +55,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(38):
with self.assertNumQueries(35):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """
program = ProgramFactory()
self.assert_retrieve_success(program)
with self.assertNumQueries(33):
self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory()
program = ProgramFactory(courses=[course])
self.assert_retrieve_success(program)
with self.assertNumQueries(55):
self.assert_retrieve_success(program)
def assert_list_results(self, url, expected):
def assert_list_results(self, url, expected, expected_query_count):
"""
Asserts the results serialized/returned at the URL matches those that are expected.
Args:
......@@ -62,7 +64,9 @@ class ProgramViewSetTests(APITestCase):
Returns:
None
"""
response = self.client.get(url)
with self.assertNumQueries(expected_query_count):
response = self.client.get(url)
self.assertEqual(
response.data['results'],
ProgramSerializer(expected, many=True, context={'request': self.request}).data
......@@ -72,17 +76,17 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3)
expected.reverse()
self.assert_list_results(self.list_path, expected)
self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program])
self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [])
self.assert_list_results(url, [], 4)
def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
......@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out.
ProgramFactory()
self.assert_list_results(url, expected)
self.assert_list_results(url, expected, 14)
@ddt.data(
(ProgramStatus.Unpublished, False),
(ProgramStatus.Active, True),
(ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True, 14),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable):
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='')
......@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase):
expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected)
self.assert_list_results(url, expected, expected_query_count)
......@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset
def prefetch_related_objects_for_programs(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Program`.
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
course = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
course = ['courses__' + field for field in course]
queryset = queryset.prefetch_related(*course)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['program'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['program'])
return queryset
# pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """
......@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime
import itertools
import logging
from collections import defaultdict
from urllib.parse import urljoin
from uuid import uuid4
......@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property
def course_runs(self):
"""
Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
Otherwise, this method will incur many, many queries when fetching related courses and course runs.
"""
excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()]
return CourseRun.objects.filter(course__programs=self).exclude(id__in=excluded_course_run_ids)
for course in self.courses.all():
for run in course.course_runs.all():
if run.id not in excluded_course_run_ids:
yield run
@property
def languages(self):
course_runs = self.course_runs.select_related('language')
return set(course_run.language for course_run in course_runs if course_run.language is not None)
return set(course_run.language for course_run in self.course_runs if course_run.language is not None)
@property
def transcript_languages(self):
course_runs = self.course_runs.prefetch_related('transcript_languages')
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = [course_run.transcript_languages.all() for course_run in self.course_runs]
languages = itertools.chain.from_iterable(languages)
return set(languages)
@property
def subjects(self):
courses = self.courses.prefetch_related('subjects')
subjects = [list(course.subjects.all()) for course in courses]
subjects = [course.subjects.all() for course in self.courses.all()]
subjects = itertools.chain.from_iterable(subjects)
return set(subjects)
@property
def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True)
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency')
applicable_seat_types = set(seat_type.slug for seat_type in self.type.applicable_seat_types.all())
for run in self.course_runs:
for seat in run.seats.all():
if seat.type in applicable_seat_types:
yield seat
@property
def seat_types(self):
return set(self.seats.values_list('type', flat=True))
return set(seat.type for seat in self.seats)
@property
def price_ranges(self):
seats = self.seats.values('currency').annotate(models.Min('price'), models.Max('price'))
price_ranges = []
currencies = defaultdict(list)
for seat in self.seats:
currencies[seat.currency].append(seat.price)
for seat in seats:
price_ranges = []
for currency, prices in currencies.items():
price_ranges.append({
'currency': seat['currency'],
'min': seat['price__min'],
'max': seat['price__max'],
'currency': currency.code,
'min': min(prices),
'max': max(prices),
})
return price_ranges
@property
def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs
if course_runs:
if self.course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates:
......@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property
def staff(self):
staff = [list(course_run.staff.all()) for course_run in self.course_runs]
staff = [course_run.staff.all() for course_run in self.course_runs]
staff = itertools.chain.from_iterable(staff)
return set(staff)
......
......@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj):
return [str(staff.uuid) for course_run in obj.course_runs.all() for staff in course_run.staff.all()]
return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj):
return self._prepare_organizations(obj.credit_backing_organizations.all())
......
......@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all()))
self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
......@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all()))
self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """
......@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all()))
self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment