Commit 6b2de5c7 by Renzo Lucioni Committed by GitHub

Fix CorporateEndorsement prefetching (#350)

endorser is a field on Endorsement, not CorporateEndorsement. We should be prefetching individual_endorsements.
parent 35dc00f7
...@@ -229,7 +229,7 @@ class CorporateEndorsementSerializer(serializers.ModelSerializer): ...@@ -229,7 +229,7 @@ class CorporateEndorsementSerializer(serializers.ModelSerializer):
@classmethod @classmethod
def prefetch_queryset(cls): def prefetch_queryset(cls):
return CorporateEndorsement.objects.all().select_related('image').prefetch_related( return CorporateEndorsement.objects.all().select_related('image').prefetch_related(
Prefetch('endorser', queryset=EndorsementSerializer.prefetch_queryset()), Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
) )
class Meta(object): class Meta(object):
......
...@@ -4,9 +4,13 @@ from rest_framework.test import APITestCase, APIRequestFactory ...@@ -4,9 +4,13 @@ from rest_framework.test import APITestCase, APIRequestFactory
from course_discovery.apps.api.serializers import ProgramSerializer from course_discovery.apps.api.serializers import ProgramSerializer
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Program from course_discovery.apps.course_metadata.models import Program
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory, CourseFactory from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, VideoFactory, OrganizationFactory, PersonFactory, ProgramFactory,
CorporateEndorsementFactory, EndorsementFactory, JobOutlookItemFactory, ExpectedLearningItemFactory
)
@ddt.ddt @ddt.ddt
...@@ -20,6 +24,26 @@ class ProgramViewSetTests(APITestCase): ...@@ -20,6 +24,26 @@ class ProgramViewSetTests(APITestCase):
self.request = APIRequestFactory().get('/') self.request = APIRequestFactory().get('/')
self.request.user = self.user self.request.user = self.user
def create_program(self):
organizations = [OrganizationFactory()]
person = PersonFactory()
course = CourseFactory()
CourseRunFactory(course=course, staff=[person])
program = ProgramFactory(
courses=[course],
authoring_organizations=organizations,
credit_backing_organizations=organizations,
corporate_endorsements=CorporateEndorsementFactory.create_batch(1),
individual_endorsements=EndorsementFactory.create_batch(1),
expected_learning_items=ExpectedLearningItemFactory.create_batch(1),
job_outlook_items=JobOutlookItemFactory.create_batch(1),
banner_image=make_image_file('test_banner.jpg'),
video=VideoFactory()
)
return program
def assert_retrieve_success(self, program): def assert_retrieve_success(self, program):
""" Verify the retrieve endpoint succesfully returns a serialized program. """ """ Verify the retrieve endpoint succesfully returns a serialized program. """
url = reverse('api:v1:program-detail', kwargs={'uuid': program.uuid}) url = reverse('api:v1:program-detail', kwargs={'uuid': program.uuid})
...@@ -39,8 +63,8 @@ class ProgramViewSetTests(APITestCase): ...@@ -39,8 +63,8 @@ class ProgramViewSetTests(APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = ProgramFactory() program = self.create_program()
with self.assertNumQueries(33): with self.assertNumQueries(89):
self.assert_retrieve_success(program) self.assert_retrieve_success(program)
def test_retrieve_without_course_runs(self): def test_retrieve_without_course_runs(self):
...@@ -74,9 +98,9 @@ class ProgramViewSetTests(APITestCase): ...@@ -74,9 +98,9 @@ class ProgramViewSetTests(APITestCase):
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3) expected = [self.create_program() for __ in range(3)]
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 14) self.assert_list_results(self.list_path, expected, 41)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment