Commit 0a3b0dfd by Clinton Blackburn Committed by GitHub

Improved performance of program API endpoint (#324)

ECOM-5559
parent 838d3330
...@@ -2,7 +2,7 @@ import ddt ...@@ -2,7 +2,7 @@ import ddt
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase, APIRequestFactory from rest_framework.test import APITestCase, APIRequestFactory
from course_discovery.apps.api.serializers import ProgramSerializer from course_discovery.apps.api.serializers import ProgramSerializer, MinimalProgramSerializer
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Program from course_discovery.apps.course_metadata.models import Program
...@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase): ...@@ -40,15 +40,17 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = ProgramFactory()
self.assert_retrieve_success(program) with self.assertNumQueries(15):
self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """ """ Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory() course = CourseFactory()
program = ProgramFactory(courses=[course]) program = ProgramFactory(courses=[course])
self.assert_retrieve_success(program) with self.assertNumQueries(15):
self.assert_retrieve_success(program)
def assert_list_results(self, url, expected): def assert_list_results(self, url, expected, expected_query_count):
""" """
Asserts the results serialized/returned at the URL matches those that are expected. Asserts the results serialized/returned at the URL matches those that are expected.
Args: Args:
...@@ -62,27 +64,29 @@ class ProgramViewSetTests(APITestCase): ...@@ -62,27 +64,29 @@ class ProgramViewSetTests(APITestCase):
Returns: Returns:
None None
""" """
response = self.client.get(url) with self.assertNumQueries(expected_query_count):
response = self.client.get(url)
self.assertEqual( self.assertEqual(
response.data['results'], response.data['results'],
ProgramSerializer(expected, many=True, context={'request': self.request}).data MinimalProgramSerializer(expected, many=True, context={'request': self.request}).data
) )
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = ProgramFactory.create_batch(3)
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected) self.assert_list_results(self.list_path, expected, 7)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name) program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program]) self.assert_list_results(url, [program], 7)
url = self.list_path + '?type=bar' url = self.list_path + '?type=bar'
self.assert_list_results(url, []) self.assert_list_results(url, [], 4)
def test_filter_by_uuids(self): def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """ """ Verify that the endpoint filters programs to those matching the provided UUIDs. """
...@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase): ...@@ -94,14 +98,14 @@ class ProgramViewSetTests(APITestCase):
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
self.assert_list_results(url, expected) self.assert_list_results(url, expected, 7)
@ddt.data( @ddt.data(
(ProgramStatus.Unpublished, False), (ProgramStatus.Unpublished, False, 4),
(ProgramStatus.Active, True), (ProgramStatus.Active, True, 7),
) )
@ddt.unpack @ddt.unpack
def test_filter_by_marketable(self, status, is_marketable): def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """ """ Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1' url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='') ProgramFactory(marketing_slug='')
...@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase): ...@@ -110,4 +114,4 @@ class ProgramViewSetTests(APITestCase):
expected = programs if is_marketable else [] expected = programs if is_marketable else []
self.assertEqual(list(Program.objects.marketable()), expected) self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected) self.assert_list_results(url, expected, expected_query_count)
...@@ -395,7 +395,6 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -395,7 +395,6 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
queryset = prefetch_related_objects_for_programs(Program.objects.all()) queryset = prefetch_related_objects_for_programs(Program.objects.all())
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.ProgramSerializer
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter filter_class = filters.ProgramFilter
...@@ -404,6 +403,12 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -404,6 +403,12 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0)) context['published_course_runs_only'] = int(self.request.GET.get('published_course_runs_only', 0))
return context return context
def get_serializer_class(self):
if self.action == 'list':
return serializers.MinimalProgramSerializer
return serializers.ProgramSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
""" List all programs. """ List all programs.
--- ---
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment