Commit 35dc00f7 by Renzo Lucioni Committed by GitHub

Merge pull request #347 from edx/renzo/query-optimization

Query optimization
parents 54feb857 bc3d7d01
...@@ -40,14 +40,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -40,14 +40,14 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = ProgramFactory()
with self.assertNumQueries(34): with self.assertNumQueries(33):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory() course = CourseFactory()
program = ProgramFactory(courses=[course]) program = ProgramFactory(courses=[course])
with self.assertNumQueries(49): with self.assertNumQueries(55):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
def assert_list_results(self, url, expected, expected_query_count): def assert_list_results(self, url, expected, expected_query_count):
...@@ -76,14 +76,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -76,14 +76,14 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = ProgramFactory.create_batch(3)
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 40) self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name) program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 18) self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar' url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 4) self.assert_list_results(url, [], 4)
...@@ -98,11 +98,11 @@ class ProgramViewSetTests(APITestCase): ...@@ -98,11 +98,11 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
self.assert_list_results(url, expected, 29) self.assert_list_results(url, expected, 14)
@ddt.data( @ddt.data(
(ProgramStatus.Unpublished, False, 4), (ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True, 40), (ProgramStatus.Active, True, 14),
) )
@ddt.unpack @ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count): def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
......
...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset return queryset
def prefetch_related_objects_for_programs(queryset):
"""
Pre-fetches the related objects that will be serialized with a `Program`.
Args:
queryset (QuerySet): original query
Returns:
QuerySet
"""
course = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
course = ['courses__' + field for field in course]
queryset = queryset.prefetch_related(*course)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['program'])
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['program'])
return queryset
# pylint: disable=no-member # pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet): class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """ """ Catalog resource. """
...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """ """ Program resource. """
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter filter_class = filters.ProgramFilter
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0)) context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime import datetime
import itertools import itertools
import logging import logging
from collections import defaultdict
from urllib.parse import urljoin from urllib.parse import urljoin
from uuid import uuid4 from uuid import uuid4
...@@ -635,57 +636,66 @@ class Program(TimeStampedModel): ...@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property @property
def course_runs(self): def course_runs(self):
"""
Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
Otherwise, this method will incur many, many queries when fetching related courses and course runs.
"""
excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()] excluded_course_run_ids = [course_run.id for course_run in self.excluded_course_runs.all()]
return CourseRun.objects.filter(course__programs=self).exclude(id__in=excluded_course_run_ids)
for course in self.courses.all():
for run in course.course_runs.all():
if run.id not in excluded_course_run_ids:
yield run
@property @property
def languages(self): def languages(self):
course_runs = self.course_runs.select_related('language') return set(course_run.language for course_run in self.course_runs if course_run.language is not None)
return set(course_run.language for course_run in course_runs if course_run.language is not None)
@property @property
def transcript_languages(self): def transcript_languages(self):
course_runs = self.course_runs.prefetch_related('transcript_languages') languages = [course_run.transcript_languages.all() for course_run in self.course_runs]
languages = [list(course_run.transcript_languages.all()) for course_run in course_runs]
languages = itertools.chain.from_iterable(languages) languages = itertools.chain.from_iterable(languages)
return set(languages) return set(languages)
@property @property
def subjects(self): def subjects(self):
courses = self.courses.prefetch_related('subjects') subjects = [course.subjects.all() for course in self.courses.all()]
subjects = [list(course.subjects.all()) for course in courses]
subjects = itertools.chain.from_iterable(subjects) subjects = itertools.chain.from_iterable(subjects)
return set(subjects) return set(subjects)
@property @property
def seats(self): def seats(self):
applicable_seat_types = self.type.applicable_seat_types.values_list('slug', flat=True) applicable_seat_types = set(seat_type.slug for seat_type in self.type.applicable_seat_types.all())
return Seat.objects.filter(course_run__in=self.course_runs, type__in=applicable_seat_types) \
.select_related('currency') for run in self.course_runs:
for seat in run.seats.all():
if seat.type in applicable_seat_types:
yield seat
@property @property
def seat_types(self): def seat_types(self):
return set(self.seats.values_list('type', flat=True)) return set(seat.type for seat in self.seats)
@property @property
def price_ranges(self): def price_ranges(self):
seats = self.seats.values('currency').annotate(models.Min('price'), models.Max('price')) currencies = defaultdict(list)
price_ranges = [] for seat in self.seats:
currencies[seat.currency].append(seat.price)
for seat in seats: price_ranges = []
for currency, prices in currencies.items():
price_ranges.append({ price_ranges.append({
'currency': seat['currency'], 'currency': currency.code,
'min': seat['price__min'], 'min': min(prices),
'max': seat['price__max'], 'max': max(prices),
}) })
return price_ranges return price_ranges
@property @property
def start(self): def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """ """ Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs if self.course_runs:
if course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start] start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates: if start_dates:
...@@ -695,7 +705,7 @@ class Program(TimeStampedModel): ...@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property @property
def staff(self): def staff(self):
staff = [list(course_run.staff.all()) for course_run in self.course_runs] staff = [course_run.staff.all() for course_run in self.course_runs]
staff = itertools.chain.from_iterable(staff) staff = itertools.chain.from_iterable(staff)
return set(staff) return set(staff)
......
...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin): ...@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()] return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj): def prepare_staff_uuids(self, obj):
return [str(staff.uuid) for course_run in obj.course_runs.all() for staff in course_run.staff.all()] return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj): def prepare_credit_backing_organizations(self, obj):
return self._prepare_organizations(obj.credit_backing_organizations.all()) return self._prepare_organizations(obj.credit_backing_organizations.all())
......
...@@ -100,7 +100,7 @@ class AdminTests(TestCase): ...@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """ """ Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count()) self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all())) self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = { params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id], 'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
...@@ -114,7 +114,7 @@ class AdminTests(TestCase): ...@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(2, self.program.excluded_course_runs.all().count()) self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all())) self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self): def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """ """ Verify that course selection page without posting any selected excluded check run. """
...@@ -132,7 +132,7 @@ class AdminTests(TestCase): ...@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200 target_status_code=200
) )
self.assertEqual(0, self.program.excluded_course_runs.all().count()) self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all())) self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,))) response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")') self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment