Commit 2e556a19 by Renzo Lucioni

Prevent course API from discarding prefetched data

filter() and exclude() calls made by the serializer after prefetching were causing prefetched data to be discarded. Moving that filtering ahead of the prefetch prevents prefetched data from being discarded and saves many duplicate queries.
parent 3e42af4f
...@@ -449,6 +449,12 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer): ...@@ -449,6 +449,12 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer):
"""A ``CourseRunSerializer`` which includes programs derived from parent course.""" """A ``CourseRunSerializer`` which includes programs derived from parent course."""
programs = serializers.SerializerMethodField() programs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None):
queryset = super().prefetch_queryset(queryset=queryset)
return queryset.prefetch_related('course__programs')
def get_programs(self, obj): def get_programs(self, obj):
programs = [] programs = []
# Filter out non-deleted programs which this course_run is part of the program course_run exclusion # Filter out non-deleted programs which this course_run is part of the program course_run exclusion
...@@ -482,14 +488,14 @@ class MinimalCourseSerializer(TimestampModelSerializer): ...@@ -482,14 +488,14 @@ class MinimalCourseSerializer(TimestampModelSerializer):
image = ImageField(read_only=True, source='card_image_url') image = ImageField(read_only=True, source='card_image_url')
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None): def prefetch_queryset(cls, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the # Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty. # queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.all()
return queryset.select_related('partner').prefetch_related( return queryset.select_related('partner').prefetch_related(
'authoring_organizations', 'authoring_organizations',
Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset()), Prefetch('course_runs', queryset=MinimalCourseRunSerializer.prefetch_queryset(queryset=course_runs)),
) )
class Meta: class Meta:
...@@ -510,7 +516,7 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -510,7 +516,7 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None): def prefetch_queryset(cls, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the # Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty. # queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.all()
...@@ -519,7 +525,7 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -519,7 +525,7 @@ class CourseSerializer(MinimalCourseSerializer):
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset()), Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
) )
...@@ -544,36 +550,26 @@ class CourseWithProgramsSerializer(CourseSerializer): ...@@ -544,36 +550,26 @@ class CourseWithProgramsSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
programs = serializers.SerializerMethodField() programs = serializers.SerializerMethodField()
def get_course_runs(self, course): @classmethod
# exclude() isn't used to avoid discarding prefetched results. def prefetch_queryset(cls, queryset=None, course_runs=None):
course_runs = [course_run for course_run in course.course_runs.all() if not course_run.hidden] """
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
if self.context.get('marketable_course_runs_only'): filtered CourseRun queryset.
# A client requesting marketable_course_runs_only should only receive course runs """
# that are published, have seats, and can still be enrolled in. All other course runs queryset = queryset if queryset is not None else Course.objects.all()
# should be excluded. As an unfortunate side-effect of the way we've marketed course
# runs in the past - a course run could be marketed despite enrollment in that run being
# closed - achieving this requires applying both the marketable and active filters.
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().active()
if self.context.get('marketable_enrollable_course_runs_with_archived'):
# Same as "marketable_course_runs_only", but includes courses with an end date in the past
# TODO: These queryset methods chain filter() and exclude() calls, causing
# prefetched results to be discarded.
course_runs = course.course_runs.exclude(hidden=True).marketable().enrollable()
if self.context.get('published_course_runs_only'): return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
# filter() isn't used to avoid discarding prefetched results. 'expected_learning_items',
course_runs = [ 'prerequisites',
course_run for course_run in course_runs if course_run.status == CourseRunStatus.Published 'subjects',
] Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
def get_course_runs(self, course):
return CourseRunSerializer( return CourseRunSerializer(
course_runs, course.course_runs,
many=True, many=True,
context={ context={
'request': self.context.get('request'), 'request': self.context.get('request'),
...@@ -603,33 +599,25 @@ class CatalogCourseSerializer(CourseSerializer): ...@@ -603,33 +599,25 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None): def prefetch_queryset(cls, queryset=None, course_runs=None):
""" """
Similar to the CourseSerializer's prefetch_queryset, but prefetches a Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset. filtered CourseRun queryset.
""" """
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.all()
available_course_runs = CourseRun.objects.active().enrollable().marketable()
return queryset.select_related('level_type', 'video', 'partner').prefetch_related( return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
Prefetch( Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
'course_runs',
queryset=CourseRunSerializer.prefetch_queryset(queryset=available_course_runs),
# Using to_attr is recommended when filtering down a prefetch
# result as it is less ambiguous than storing a filtered result
# in the related manager’s cache and accessing it via all().
to_attr='available_course_runs'
),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
) )
def get_course_runs(self, course): def get_course_runs(self, course):
return CourseRunSerializer( return CourseRunSerializer(
course.available_course_runs, course.course_runs,
many=True, many=True,
context=self.context context=self.context
).data ).data
......
...@@ -201,99 +201,6 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests): ...@@ -201,99 +201,6 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
) )
self.assertEqual(serializer.data, self.get_expected_data(self.course, self.request)) self.assertEqual(serializer.data, self.get_expected_data(self.course, self.request))
@ddt.data(0, 1)
def test_marketable_course_runs_only(self, marketable_course_runs_only):
"""
Verify that the marketable_course_runs_only option is respected, restricting returned
course runs to those that are published, have seats, and can still be enrolled in.
"""
enrollable_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
SeatFactory(course_run=enrollable_course_run)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
SeatFactory(course_run=unpublished_course_run)
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
closed_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=10),
course=self.course
)
SeatFactory(course_run=closed_course_run)
serializer = self.serializer_class(
self.course,
context={'request': self.request, 'marketable_course_runs_only': marketable_course_runs_only}
)
self.assertEqual(
len(serializer.data['course_runs']),
1 if marketable_course_runs_only else 4
)
def test_marketable_enrollable_course_runs_with_archived(self):
"""
Verify that the marketable_enrollable_course_runs_with_archived option is respected, restricting returned
course runs to those that are published, have seats, and can still be enrolled in
(including courses with an end date in the past.)
"""
enrollable_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
archived_course_run = CourseRunFactory(
status=CourseRunStatus.Published,
end=datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=10),
enrollment_start=None,
enrollment_end=None,
course=self.course
)
SeatFactory(course_run=unpublished_course_run)
SeatFactory(course_run=enrollable_course_run)
SeatFactory(course_run=archived_course_run)
context = {
'request': self.request,
'marketable_enrollable_course_runs_with_archived': 1
}
course_serializer = self.serializer_class(
self.course,
context=context
)
course_run_keys = [course_run['key'] for course_run in course_serializer.data['course_runs']]
# order doesn't matter
assert sorted(course_run_keys) == sorted([enrollable_course_run.key, archived_course_run.key])
@ddt.data(0, 1)
def test_published_course_runs_only(self, published_course_runs_only):
"""
Test that the published_course_runs_only flag hides unpublished course runs
"""
unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished)
published_course_run = CourseRunFactory(status=CourseRunStatus.Published)
self.course.course_runs.add(unpublished_course_run, published_course_run)
serializer = self.serializer_class(
self.course,
context={'request': self.request, 'published_course_runs_only': published_course_runs_only}
)
self.assertEqual(len(serializer.data['course_runs']), 2 - published_course_runs_only)
class MinimalCourseRunSerializerTests(TestCase): class MinimalCourseRunSerializerTests(TestCase):
serializer_class = MinimalCourseRunSerializer serializer_class = MinimalCourseRunSerializer
......
...@@ -170,15 +170,16 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -170,15 +170,16 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# This run has no seats, but we still expect its parent course # This run has no seats, but we still expect its parent course
# to be included. # to be included.
CourseRunFactory(course=course) filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(18): with self.assertNumQueries(18):
response = self.client.get(url) response = self.client.get(url)
# Prefetched results are assigned to a custom attribute.
course.available_course_runs = [course_run]
assert response.status_code == 200 assert response.status_code == 200
# Emulate prefetching behavior.
filtered_course_run.delete()
assert response.data['results'] == self.serialize_catalog_course([course], many=True) assert response.data['results'] == self.serialize_catalog_course([course], many=True)
# Any course appearing in the response must have at least one serialized run. # Any course appearing in the response must have at least one serialized run.
......
...@@ -51,7 +51,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -51,7 +51,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key}) url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key})
with self.assertNumQueries(13): with self.assertNumQueries(10):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
assert response.data.get('programs') == [] assert response.data.get('programs') == []
...@@ -78,7 +78,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -78,7 +78,7 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key}) url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key})
with self.assertNumQueries(13): with self.assertNumQueries(11):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
assert response.data.get('programs') == [] assert response.data.get('programs') == []
...@@ -130,10 +130,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -130,10 +130,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
assert response.status_code == 403 assert response.status_code == 403
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all catalogs. """ """ Verify the endpoint returns a list of all course runs. """
url = reverse('api:v1:course_run-list') url = reverse('api:v1:course_run-list')
with self.assertNumQueries(12): with self.assertNumQueries(13):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
...@@ -143,10 +143,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -143,10 +143,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
) )
def test_list_sorted_by_course_start_date(self): def test_list_sorted_by_course_start_date(self):
""" Verify the endpoint returns a list of all catalogs sorted by course start date. """ """ Verify the endpoint returns a list of all course runs sorted by start date. """
url = '{root}?ordering=start'.format(root=reverse('api:v1:course_run-list')) url = '{root}?ordering=start'.format(root=reverse('api:v1:course_run-list'))
with self.assertNumQueries(12): with self.assertNumQueries(13):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
......
...@@ -80,7 +80,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -80,7 +80,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
SeatFactory(course_run=unpublished_course_run) SeatFactory(course_run=unpublished_course_run)
# Published course run with no seats. # Published course run with no seats.
CourseRunFactory(status=CourseRunStatus.Published, course=self.course) no_seats_course_run = CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
# Published course run with a seat and an end date in the past. # Published course run with a seat and an end date in the past.
closed_course_run = CourseRunFactory( closed_course_run = CourseRunFactory(
...@@ -94,14 +94,14 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -94,14 +94,14 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
url = '{}?marketable_course_runs_only={}'.format(url, marketable_course_runs_only) url = '{}?marketable_course_runs_only={}'.format(url, marketable_course_runs_only)
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) assert response.status_code == 200
self.assertEqual(
response.data, if marketable_course_runs_only:
self.serialize_course( # Emulate prefetching behavior.
self.course, for course_run in (unpublished_course_run, no_seats_course_run, closed_course_run):
extra_context={'marketable_course_runs_only': marketable_course_runs_only} course_run.delete()
)
) assert response.data == self.serialize_course(self.course)
@ddt.data(1, 0) @ddt.data(1, 0)
def test_marketable_enrollable_course_runs_with_archived(self, marketable_enrollable_course_runs_with_archived): def test_marketable_enrollable_course_runs_with_archived(self, marketable_enrollable_course_runs_with_archived):
...@@ -111,13 +111,17 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -111,13 +111,17 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
past = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2) past = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2)
future = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2) future = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
CourseRunFactory(enrollment_start=None, enrollment_end=future, course=self.course) course_run = CourseRunFactory(enrollment_start=None, enrollment_end=future, course=self.course)
CourseRunFactory(enrollment_start=None, enrollment_end=None, course=self.course) SeatFactory(course_run=course_run)
CourseRunFactory(
enrollment_start=past, enrollment_end=future, course=self.course filtered_course_runs = [
) CourseRunFactory(enrollment_start=None, enrollment_end=None, course=self.course),
CourseRunFactory(enrollment_start=future, course=self.course) CourseRunFactory(
CourseRunFactory(enrollment_end=past, course=self.course) enrollment_start=past, enrollment_end=future, course=self.course
),
CourseRunFactory(enrollment_start=future, course=self.course),
CourseRunFactory(enrollment_end=past, course=self.course),
]
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url = '{}?marketable_enrollable_course_runs_with_archived={}'.format( url = '{}?marketable_enrollable_course_runs_with_archived={}'.format(
...@@ -126,12 +130,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -126,12 +130,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
assert response.data == self.serialize_course(
self.course, if marketable_enrollable_course_runs_with_archived:
extra_context={ # Emulate prefetching behavior.
'marketable_enrollable_course_runs_with_archived': marketable_enrollable_course_runs_with_archived for course_run in filtered_course_runs:
} course_run.delete()
)
assert response.data == self.serialize_course(self.course)
@ddt.data(1, 0) @ddt.data(1, 0)
def test_get_include_published_course_run(self, published_course_runs_only): def test_get_include_published_course_run(self, published_course_runs_only):
...@@ -140,15 +145,20 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -140,15 +145,20 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
the 'published_course_runs_only' flag is set to True the 'published_course_runs_only' flag is set to True
""" """
CourseRunFactory(status=CourseRunStatus.Published, course=self.course) CourseRunFactory(status=CourseRunStatus.Published, course=self.course)
CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course) unpublished_course_run = CourseRunFactory(status=CourseRunStatus.Unpublished, course=self.course)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url = '{}?published_course_runs_only={}'.format(url, published_course_runs_only) url = '{}?published_course_runs_only={}'.format(url, published_course_runs_only)
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual( assert response.status_code == 200
response.data,
self.serialize_course(self.course, extra_context={'published_course_runs_only': published_course_runs_only}) if published_course_runs_only:
) # Emulate prefetching behavior.
unpublished_course_run.delete()
assert response.data == self.serialize_course(self.course)
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
...@@ -170,7 +180,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -170,7 +180,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(58): with self.assertNumQueries(37):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
...@@ -91,7 +91,12 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -91,7 +91,12 @@ class CatalogViewSet(viewsets.ModelViewSet):
""" """
catalog = self.get_object() catalog = self.get_object()
queryset = catalog.courses().available() queryset = catalog.courses().available()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset(queryset=queryset) course_runs = CourseRun.objects.active().enrollable().marketable()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset(
queryset=queryset,
course_runs=course_runs
)
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
serializer = serializers.CatalogCourseSerializer(page, many=True, context={'request': request}) serializer = serializers.CatalogCourseSerializer(page, many=True, context={'request': request})
......
...@@ -40,7 +40,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -40,7 +40,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
paramType: query paramType: query
multiple: false multiple: false
""" """
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q')
partner = self.get_partner() partner = self.get_partner()
if q: if q:
...@@ -50,9 +50,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -50,9 +50,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
return qs return qs
else: else:
queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner) queryset = super(CourseRunViewSet, self).get_queryset().filter(course__partner=partner)
queryset = queryset.select_related(*serializers.SELECT_RELATED_FIELDS['course_run']) return self.get_serializer_class().prefetch_queryset(queryset=queryset)
queryset = queryset.prefetch_related(*serializers.PREFETCH_FIELDS['course_run'])
return queryset
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
......
...@@ -6,8 +6,9 @@ from rest_framework.permissions import IsAuthenticated ...@@ -6,8 +6,9 @@ from rest_framework.permissions import IsAuthenticated
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.v1.views import get_query_param from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course, CourseRun
# pylint: disable=no-member # pylint: disable=no-member
...@@ -26,19 +27,34 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -26,19 +27,34 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination pagination_class = ProxiedPagination
def get_queryset(self): def get_queryset(self):
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q')
if q: if q:
queryset = Course.search(q) queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset)
else: else:
queryset = self.get_serializer_class().prefetch_queryset() course_runs = CourseRun.objects.exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active()
if get_query_param(self.request, 'marketable_enrollable_course_runs_with_archived'):
course_runs = course_runs.marketable().enrollable()
if get_query_param(self.request, 'published_course_runs_only'):
course_runs = course_runs.filter(status=CourseRunStatus.Published)
queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset,
course_runs=course_runs
)
return queryset.order_by(Lower('key')) return queryset.order_by(Lower('key'))
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
query_params = ['exclude_utm', 'include_deleted_programs', 'marketable_course_runs_only', query_params = ['exclude_utm', 'include_deleted_programs']
'marketable_enrollable_course_runs_with_archived', 'published_course_runs_only']
for query_param in query_params: for query_param in query_params:
context[query_param] = get_query_param(self.request, query_param) context[query_param] = get_query_param(self.request, query_param)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment