Commit 35dc00f7 by Renzo Lucioni Committed by GitHub

Merge pull request #347 from edx/renzo/query-optimization

Query optimization
parents 54feb857 bc3d7d01
......@@ -40,14 +40,14 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self):
    """ Verify the endpoint returns the details for a single program. """
    program = ProgramFactory()

    # Only one query-count expectation can hold for a single request; the
    # stale pre-optimization count (34) is dropped in favor of the current one.
    with self.assertNumQueries(33):
        self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self):
    """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
    course = CourseFactory()
    program = ProgramFactory(courses=[course])

    # Only one query-count expectation can hold for a single request; the
    # stale count from the previous revision (49) is dropped.
    with self.assertNumQueries(55):
        self.assert_retrieve_success(program)
def assert_list_results(self, url, expected, expected_query_count):
......@@ -76,14 +76,14 @@ class ProgramViewSetTests(APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3)
expected.reverse()
self.assert_list_results(self.list_path, expected, 40)
self.assert_list_results(self.list_path, expected, 14)
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 18)
self.assert_list_results(url, [program], 14)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 4)
......@@ -98,11 +98,11 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out.
ProgramFactory()
self.assert_list_results(url, expected, 29)
self.assert_list_results(url, expected, 14)
@ddt.data(
(ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True, 40),
(ProgramStatus.Active, True, 14),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
......
......@@ -64,23 +64,6 @@ def prefetch_related_objects_for_courses(queryset):
return queryset
def prefetch_related_objects_for_programs(queryset):
    """
    Attach the select/prefetch lookups used when serializing a `Program`.

    The course-level lookups (both the prefetch and the select-related field
    lists) are applied through the `courses__` relation as prefetches; the
    program-level lists are applied directly.

    Args:
        queryset (QuerySet): original query

    Returns:
        QuerySet
    """
    course_fields = serializers.PREFETCH_FIELDS['course'] + serializers.SELECT_RELATED_FIELDS['course']
    course_lookups = ['courses__' + name for name in course_fields]

    return (
        queryset
        .prefetch_related(*course_lookups)
        .select_related(*serializers.SELECT_RELATED_FIELDS['program'])
        .prefetch_related(*serializers.PREFETCH_FIELDS['program'])
    )
# pylint: disable=no-member
class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """
......@@ -393,12 +376,16 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
""" Program resource. """
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter
def get_queryset(self):
    """
    Return the program queryset produced by the serializer's `prefetch_queryset()`.

    Building the queryset per call, rather than storing it in a class
    property, prevents prefetches on the program queryset from "stacking."
    """
    serializer_class = self.serializer_class
    return serializer_class.prefetch_queryset()
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
......
import datetime
import itertools
import logging
from collections import defaultdict
from urllib.parse import urljoin
from uuid import uuid4
......@@ -635,57 +636,66 @@ class Program(TimeStampedModel):
@property
def course_runs(self):
    """
    Yield each course run of this program's courses, skipping excluded runs.

    Warning! Only call this method after retrieving programs from `ProgramSerializer.prefetch_queryset()`.
    Otherwise, this method will incur many, many queries when fetching related courses and course runs.

    Note that this is a generator: it has no queryset methods and is always
    truthy, so callers must iterate it rather than test or filter it.
    """
    # A set keeps the membership test inside the nested loops O(1).
    excluded_course_run_ids = {course_run.id for course_run in self.excluded_course_runs.all()}

    for course in self.courses.all():
        for run in course.course_runs.all():
            if run.id not in excluded_course_run_ids:
                yield run
@property
def languages(self):
    """
    Return the set of languages of this program's course runs.

    Course runs without a language (`None`) are ignored. `course_runs` is
    iterated directly because it is a generator, not a queryset.
    """
    return {course_run.language for course_run in self.course_runs if course_run.language is not None}
@property
def transcript_languages(self):
    """
    Return the set of transcript languages across this program's course runs.

    `course_runs` is iterated directly because it is a generator, not a queryset.
    """
    per_run = (course_run.transcript_languages.all() for course_run in self.course_runs)
    return set(itertools.chain.from_iterable(per_run))
@property
def subjects(self):
    """
    Return the set of subjects across this program's courses.

    Each course's subjects are read through `.all()` without adding a new
    prefetch to the courses relation.
    """
    per_course = (course.subjects.all() for course in self.courses.all())
    return set(itertools.chain.from_iterable(per_course))
@property
def seats(self):
    """
    Yield the seats of this program's course runs whose type applies to the program.

    A seat is yielded when its `type` is one of the slugs of the program
    type's `applicable_seat_types`. This is a generator, so callers must
    iterate it; it has no queryset methods.
    """
    applicable_seat_types = {seat_type.slug for seat_type in self.type.applicable_seat_types.all()}

    for run in self.course_runs:
        for seat in run.seats.all():
            if seat.type in applicable_seat_types:
                yield seat
@property
def seat_types(self):
    """
    Return the set of distinct seat types among this program's seats.

    `seats` is iterated directly because it is a generator, not a queryset.
    """
    return {seat.type for seat in self.seats}
@property
def price_ranges(self):
    """
    Return the minimum and maximum seat price for each currency.

    Seats are grouped by their `currency` in Python, by iterating `seats`,
    instead of through an aggregate query.

    Returns:
        list of dict: one entry per currency, in first-seen order, with the
        keys `currency` (the currency's `code`), `min` and `max`.
    """
    prices_by_currency = defaultdict(list)
    for seat in self.seats:
        prices_by_currency[seat.currency].append(seat.price)

    return [
        {
            'currency': currency.code,
            'min': min(prices),
            'max': max(prices),
        }
        for currency, prices in prices_by_currency.items()
    ]
@property
def start(self):
""" Start datetime, calculated by determining the earliest start datetime of all related course runs. """
course_runs = self.course_runs
if course_runs:
if self.course_runs:
start_dates = [course_run.start for course_run in self.course_runs if course_run.start]
if start_dates:
......@@ -695,7 +705,7 @@ class Program(TimeStampedModel):
@property
def staff(self):
    """
    Return the set of staff members across this program's course runs.

    `course_runs` is walked exactly once; each run's staff is read through `.all()`.
    """
    per_run = (course_run.staff.all() for course_run in self.course_runs)
    return set(itertools.chain.from_iterable(per_run))
......
......@@ -201,7 +201,7 @@ class ProgramIndex(BaseIndex, indexes.Indexable, OrganizationsMixin):
return [str(subject.uuid) for course in obj.courses.all() for subject in course.subjects.all()]
def prepare_staff_uuids(self, obj):
    """
    Return the UUIDs, as strings, of the staff of every course run of `obj`.

    `obj.course_runs` is iterated directly, without `.all()`. A staff member
    attached to several course runs appears once per run.
    """
    return [str(staff.uuid) for course_run in obj.course_runs for staff in course_run.staff.all()]
def prepare_credit_backing_organizations(self, obj):
    """ Return `obj`'s credit-backing organizations as prepared by `_prepare_organizations`. """
    organizations = obj.credit_backing_organizations.all()
    return self._prepare_organizations(organizations)
......
......@@ -100,7 +100,7 @@ class AdminTests(TestCase):
""" Verify that course selection page with posting the data. """
self.assertEqual(1, self.program.excluded_course_runs.all().count())
self.assertEqual(3, len(self.program.course_runs.all()))
self.assertEqual(3, sum(1 for _ in self.program.course_runs))
params = {
'excluded_course_runs': [self.excluded_course_run.id, self.course_runs[0].id],
......@@ -114,7 +114,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(2, self.program.excluded_course_runs.all().count())
self.assertEqual(2, len(self.program.course_runs.all()))
self.assertEqual(2, sum(1 for _ in self.program.course_runs))
def test_page_with_post_without_course_run(self):
""" Verify that course selection page without posting any selected excluded check run. """
......@@ -132,7 +132,7 @@ class AdminTests(TestCase):
target_status_code=200
)
self.assertEqual(0, self.program.excluded_course_runs.all().count())
self.assertEqual(4, len(self.program.course_runs.all()))
self.assertEqual(4, sum(1 for _ in self.program.course_runs))
response = self.client.get(reverse('admin_metadata:update_course_runs', args=(self.program.id,)))
self.assertNotContains(response, '<input checked="checked")')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment