Commit cda1efbe by Vedran Karacic Committed by Afzal Wali Naushahi

Filter API data by partner

LEARNER-1119
parent a6365d68
import logging
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db.models import QuerySet
from django.utils.translation import ugettext as _
......@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import NotFound, PermissionDenied
from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program
......@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code
filters['partner'] = request.site.partner.short_code
return filters
......
......@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags = TagListSerializerField()
@classmethod
def prefetch_queryset(cls):
return Organization.objects.all().select_related('partner').prefetch_related('tags')
def prefetch_queryset(cls, partner):
return Organization.objects.filter(partner=partner).select_related('partner').prefetch_related('tags')
class Meta(MinimalOrganizationSerializer.Meta):
fields = MinimalOrganizationSerializer.Meta.fields + (
......@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
class Meta(MinimalCourseSerializer.Meta):
......@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
def get_course_runs(self, course):
......@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
def get_course_runs(self, course):
......@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
@classmethod
def prefetch_queryset(cls):
return Program.objects.all().select_related('type', 'partner').prefetch_related(
def prefetch_queryset(cls, partner):
return Program.objects.filter(partner=partner).select_related('type', 'partner').prefetch_related(
'excluded_course_runs',
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
......@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls):
def prefetch_queryset(cls, partner):
"""
Prefetch the related objects that will be serialized with a `Program`.
......@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
"""
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related(
return Program.objects.filter(partner=partner).select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs',
'expected_learning_items',
'faq',
......@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types',
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
)
......
from django.conf import settings
from django.contrib.sites.models import Site
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
class PartnerMixin(object):
def setUp(self):
super(PartnerMixin, self).setUp()
Site.objects.all().delete()
self.site = SiteFactory(id=settings.SITE_ID)
self.partner = PartnerFactory(site=self.site)
......@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer
)
from course_discovery.apps.api.tests.mixins import PartnerMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
......@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(TestCase):
class MinimalCourseSerializerTests(PartnerMixin, TestCase):
serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request):
......@@ -113,8 +114,8 @@ class MinimalCourseSerializerTests(TestCase):
def test_data(self):
request = make_request()
organizations = OrganizationFactory()
course = CourseFactory(authoring_organizations=[organizations])
organizations = OrganizationFactory(partner=self.partner)
course = CourseFactory(authoring_organizations=[organizations], partner=self.partner)
CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request)
......@@ -177,9 +178,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def setUp(self):
super().setUp()
self.request = make_request()
self.course = CourseFactory()
self.course = CourseFactory(partner=self.partner)
self.deleted_program = ProgramFactory(
courses=[self.course],
partner=self.partner,
status=ProgramStatus.Deleted
)
......
import ddt
import pytest
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import PermissionDenied
from django.test import RequestFactory, TestCase
from django.urls import reverse
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.views import api_docs_permission_denied_handler
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
@pytest.mark.django_db
class TestApiDocs:
class TestApiDocs(APITestCase):
"""
Regression tests introduced following LEARNER-1590.
"""
path = reverse('api_docs')
def test_api_docs(self, admin_client):
def test_api_docs(self):
"""
Verify that the API docs are available to authenticated clients.
"""
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
response = admin_client.get(self.path)
user = UserFactory(is_staff=True)
self.client.login(username=user.username, password=USER_PASSWORD)
response = self.client.get(self.path)
assert response.status_code == 200
def test_api_docs_redirect(self, client):
def test_api_docs_redirect(self):
"""
Verify that unauthenticated clients are redirected.
"""
response = client.get(self.path)
response = self.client.get(self.path)
assert response.status_code == 302
......
......@@ -4,6 +4,7 @@ import json
import responses
from django.conf import settings
from rest_framework.test import APITestCase as RestAPITestCase
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.serializers import (
......@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer,
OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer
)
from course_discovery.apps.api.tests.mixins import PartnerMixin
class SerializationMixin(object):
......@@ -88,3 +90,7 @@ class OAuth2Mixin(object):
content_type='application/json',
status=status
)
class APITestCase(PartnerMixin, RestAPITestCase):
pass
......@@ -46,7 +46,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def test_affiliate_with_supported_seats(self):
""" Verify that endpoint returns course runs for verified and professional seats only. """
with self.assertNumQueries(8):
with self.assertNumQueries(9):
response = self.client.get(self.affiliate_url)
self.assertEqual(response.status_code, 200)
......@@ -130,7 +130,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs
self.client.force_authenticate(superuser)
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
......
......@@ -185,8 +185,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run.
assert len(response.data['results'][0]['course_runs']) > 0
else:
with self.assertNumQueries(2):
response = self.client.get(url)
response = self.client.get(url)
assert response.status_code == 200
assert response.data['results'] == []
......@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(17):
with self.assertNumQueries(18):
response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
......@@ -4,19 +4,16 @@ import urllib
import ddt
import pytz
from django.conf import settings
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory, APITestCase
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, PartnerFactory, ProgramFactory, SeatFactory
)
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, SeatFactory
@ddt.ddt
......@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
self.course_run = CourseRunFactory(course__partner=self.partner)
self.course_run_2 = CourseRunFactory(course__partner=self.partner)
self.refresh_index()
......@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def assert_list_results(self, url, expected, extra_context=None):
expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url)
......@@ -268,18 +252,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
}
)
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
assert response.status_code == 400
def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({
'query': 'id:course*',
......
......@@ -4,9 +4,8 @@ import ddt
import pytz
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course
......@@ -23,13 +22,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
super(CourseViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory()
self.course = CourseFactory(partner=self.partner)
def test_get(self):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(18):
with self.assertNumQueries(20):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -38,7 +37,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(11):
with self.assertNumQueries(13):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), [])
......@@ -51,7 +50,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1'
with self.assertNumQueries(22):
with self.assertNumQueries(24):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
......@@ -187,7 +186,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(24):
with self.assertNumQueries(26):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -203,18 +202,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses = CourseFactory.create_batch(3)
courses = CourseFactory.create_batch(3, partner=self.partner)
courses = sorted(courses, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import uuid
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import Organization, OrganizationFactory
......@@ -27,17 +26,19 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_response_data_valid(self, response, organizations, many=True):
""" Asserts the response data (only) contains the expected organizations. """
actual = response.data
serializer_data = self.serialize_organization(organizations, many=many)
if many:
actual = actual['results']
actual = sorted(actual, key=lambda k: k['uuid'])
serializer_data = sorted(serializer_data, key=lambda k: k['uuid'])
self.assertEqual(actual, self.serialize_organization(organizations, many=many))
self.assertEqual(actual, serializer_data)
def assert_list_uuid_filter(self, organizations):
def assert_list_uuid_filter(self, organizations, expected_query_count):
""" Asserts the list endpoint supports filtering by UUID. """
with self.assertNumQueries(5):
with self.assertNumQueries(expected_query_count):
uuids = ','.join([organization.uuid.hex for organization in organizations])
url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids)
response = self.client.get(url)
......@@ -47,7 +48,6 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_list_tag_filter(self, organizations, tags, expected_query_count=5):
""" Asserts the list endpoint supports filtering by tags. """
with self.assertNumQueries(expected_query_count):
tags = ','.join(tags)
url = '{root}?tags={tags}'.format(root=self.list_path, tags=tags)
......@@ -58,10 +58,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list(self):
""" Verify the endpoint returns a list of all organizations. """
OrganizationFactory.create_batch(3, partner=self.partner)
OrganizationFactory.create_batch(3)
with self.assertNumQueries(5):
with self.assertNumQueries(8):
response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200)
......@@ -70,22 +69,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list_uuid_filter(self):
""" Verify the endpoint returns a list of organizations filtered by UUID. """
organizations = OrganizationFactory.create_batch(3)
organizations = OrganizationFactory.create_batch(3, partner=self.partner)
# Test with a single UUID
self.assert_list_uuid_filter([organizations[0]])
self.assert_list_uuid_filter([organizations[0]], 7)
# Test with multiple UUIDs
self.assert_list_uuid_filter(organizations)
self.assert_list_uuid_filter(organizations, 5)
def test_list_tag_filter(self):
""" Verify the endpoint returns a list of organizations filtered by tag. """
tag = 'test-org'
organizations = OrganizationFactory.create_batch(2)
organizations = OrganizationFactory.create_batch(2, partner=self.partner)
# If no organizations have been tagged, the endpoint should not return any data
self.assert_list_tag_filter([], [tag], expected_query_count=3)
self.assert_list_tag_filter([], [tag], expected_query_count=5)
# Tagged organizations should be returned
organizations[0].tags.add(tag)
......@@ -99,7 +98,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns details for a single organization. """
organization = OrganizationFactory()
organization = OrganizationFactory(partner=self.partner)
url = reverse('api:v1:organization-detail', kwargs={'uuid': organization.uuid})
response = self.client.get(url)
......
# pylint: disable=redefined-builtin,no-member
import ddt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import IntegrityError
from mock import mock
......@@ -8,20 +7,19 @@ from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from testfixtures import LogCapture
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import PartnerMixin, SerializationMixin
from course_discovery.apps.api.v1.views.people import logger as people_logger
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.models import Person
from course_discovery.apps.course_metadata.people import MarketingSitePeople
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import (OrganizationFactory, PartnerFactory, PersonFactory,
PositionFactory)
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory, PositionFactory
User = get_user_model()
@ddt.ddt
class PersonViewSetTests(SerializationMixin, APITestCase):
class PersonViewSetTests(SerializationMixin, PartnerMixin, APITestCase):
""" Tests for the person resource. """
people_list_url = reverse('api:v1:person-list')
......@@ -32,10 +30,6 @@ class PersonViewSetTests(SerializationMixin, APITestCase):
self.person = PersonFactory()
PositionFactory(person=self.person)
self.organization = OrganizationFactory()
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
toggle_switch('publish_person_to_marketing_site', True)
self.expected_node = {
'resource': 'node', ''
......
......@@ -28,7 +28,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """
ProgramTypeFactory.create_batch(4)
expected = ProgramType.objects.all()
with self.assertNumQueries(5):
with self.assertNumQueries(6):
response = self.client.get(self.list_path)
assert response.status_code == 200
......@@ -39,7 +39,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type = ProgramTypeFactory()
url = reverse('api:v1:program_type-detail', kwargs={'slug': program_type.slug})
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.client.get(url)
assert response.status_code == 200
......
......@@ -3,10 +3,9 @@ import urllib.parse
import ddt
from django.core.cache import cache
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import MinimalProgramSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -31,10 +30,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
cache.clear()
def create_program(self):
organizations = [OrganizationFactory()]
organizations = [OrganizationFactory(partner=self.partner)]
person = PersonFactory()
course = CourseFactory()
course = CourseFactory(partner=self.partner)
CourseRunFactory(course=course, staff=[person])
program = ProgramFactory(
......@@ -46,7 +45,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
expected_learning_items=ExpectedLearningItemFactory.create_batch(1),
job_outlook_items=JobOutlookItemFactory.create_batch(1),
banner_image=make_image_file('test_banner.jpg'),
video=VideoFactory()
video=VideoFactory(),
partner=self.partner
)
return program
......@@ -73,7 +73,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """
program = self.create_program()
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
......@@ -90,22 +90,25 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
@ddt.data(True, False)
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date):
""" Verify the number of queries is the same with sorting flag set to true. """
course_list = CourseFactory.create_batch(3)
course_list = CourseFactory.create_batch(3, partner=self.partner)
for course in course_list:
CourseRunFactory(course=course)
program = ProgramFactory(courses=course_list, order_courses_by_start_date=order_courses_by_start_date)
program = ProgramFactory(
courses=course_list,
order_courses_by_start_date=order_courses_by_start_date,
partner=self.partner)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
with self.assertNumQueries(26):
with self.assertNumQueries(28):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
self.assertEqual(course_list, list(program.courses.all())) # pylint: disable=no-member
def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory()
program = ProgramFactory(courses=[course])
with self.assertNumQueries(20):
course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner)
with self.assertNumQueries(22):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
......@@ -135,7 +138,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)]
expected.reverse()
self.assert_list_results(self.list_path, expected, 12)
self.assert_list_results(self.list_path, expected, 14)
# Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 2)
......@@ -145,8 +148,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
Verify that the list view returns a simply list of UUIDs when the
uuids_only query parameter is passed.
"""
active = ProgramFactory.create_batch(3)
retired = [ProgramFactory(status=ProgramStatus.Retired)]
active = ProgramFactory.create_batch(3, partner=self.partner)
retired = [ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)]
programs = active + retired
querystring = {'uuids_only': 1}
......@@ -165,47 +168,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
program = ProgramFactory(type__name=program_type_name, partner=self.partner)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 8)
self.assert_list_results(url, [program], 10)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 3)
def test_filter_by_types(self):
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
expected = ProgramFactory.create_batch(2)
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected.reverse()
type_slugs = [p.type.slug for p in expected]
url = self.list_path + '?types=' + ','.join(type_slugs)
# Create a third program, which should be filtered out.
ProgramFactory()
ProgramFactory(partner=self.partner)
self.assert_list_results(url, expected, 8)
self.assert_list_results(url, expected, 10)
def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
expected = ProgramFactory.create_batch(2)
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected.reverse()
uuids = [str(p.uuid) for p in expected]
url = self.list_path + '?uuids=' + ','.join(uuids)
# Create a third program, which should be filtered out.
ProgramFactory()
ProgramFactory(partner=self.partner)
self.assert_list_results(url, expected, 8)
self.assert_list_results(url, expected, 10)
@ddt.data(
(ProgramStatus.Unpublished, False, 3),
(ProgramStatus.Active, True, 8),
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='')
programs = ProgramFactory.create_batch(3, status=status)
ProgramFactory(marketing_slug='', partner=self.partner)
programs = ProgramFactory.create_batch(3, status=status, partner=self.partner)
programs.reverse()
expected = programs if is_marketable else []
......@@ -214,11 +217,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_status(self):
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
active = ProgramFactory(status=ProgramStatus.Active)
retired = ProgramFactory(status=ProgramStatus.Retired)
active = ProgramFactory(status=ProgramStatus.Active, partner=self.partner)
retired = ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)
url = self.list_path + '?status=active'
self.assert_list_results(url, [active], 8)
self.assert_list_results(url, [active], 10)
url = self.list_path + '?status=retired'
self.assert_list_results(url, [retired], 8)
......@@ -228,11 +231,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_hidden(self):
""" Endpoint should filter programs by their hidden attribute value. """
hidden = ProgramFactory(hidden=True)
not_hidden = ProgramFactory(hidden=False)
hidden = ProgramFactory(hidden=True, partner=self.partner)
not_hidden = ProgramFactory(hidden=False, partner=self.partner)
url = self.list_path + '?hidden=True'
self.assert_list_results(url, [hidden], 8)
self.assert_list_results(url, [hidden], 10)
url = self.list_path + '?hidden=False'
self.assert_list_results(url, [not_hidden], 8)
......@@ -247,7 +250,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
url = self.list_path + '?exclude_utm=1'
program = self.create_program()
self.assert_list_results(url, [program], 12, extra_context={'exclude_utm': 1})
self.assert_list_results(url, [program], 14, extra_context={'exclude_utm': 1})
def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """
......
......@@ -3,13 +3,12 @@ import json
import urllib.parse
import ddt
from django.conf import settings
from django.urls import reverse
from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.views.search import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
......@@ -88,14 +87,8 @@ class SynonymTestMixin:
self.assertDictEqual(response1, response2)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt
class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
""" Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets')
......@@ -271,7 +264,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
)
self.reindex_courses(program)
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -295,7 +288,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=[course_run.course], status=program_status)
self.reindex_courses(active_program)
with self.assertNumQueries(5):
with self.assertNumQueries(6):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -313,7 +306,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt
class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets')
......@@ -438,7 +431,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
assert expected == actual
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead')
......@@ -620,23 +613,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
self.serialize_program(harvard_program)]
}
self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
......@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
......@@ -9,14 +9,13 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.utils import get_query_param
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member
class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
class CourseRunViewSet(viewsets.ModelViewSet):
""" CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter
......@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
multiple: false
"""
q = self.request.query_params.get('q')
partner = self.get_partner()
partner = self.request.site.partner
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
......@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
type: string
paramType: query
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden
description: Filter based on wether the course run is hidden from search.
required: false
......@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner = self.get_partner()
partner = self.request.site.partner
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
......
......@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class = filters.CourseFilter
lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseWithProgramsSerializer
......@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
q = self.request.query_params.get('q')
if q:
queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset, partner=partner)
else:
if get_query_param(self.request, 'include_hidden_course_runs'):
course_runs = CourseRun.objects.all()
course_runs = CourseRun.objects.filter(course__partner=partner)
else:
course_runs = CourseRun.objects.exclude(hidden=True)
course_runs = CourseRun.objects.filter(course__partner=partner).exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active()
......@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset,
course_runs=course_runs
course_runs=course_runs,
partner=partner
)
return queryset.order_by(Lower('key'))
......
......@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
permission_classes = (IsAuthenticated,)
queryset = serializers.OrganizationSerializer.prefetch_queryset()
serializer_class = serializers.OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination.
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
return serializers.OrganizationSerializer.prefetch_queryset(partner=partner)
def list(self, request, *args, **kwargs):
""" Retrieve a list of all organizations. """
return super(OrganizationViewSet, self).list(request, *args, **kwargs)
......
......@@ -7,7 +7,6 @@ from rest_framework.response import Response
from course_discovery.apps.api import serializers
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.exceptions import MarketingSiteAPIClientException, PersonToMarketingException
from course_discovery.apps.course_metadata.people import MarketingSitePeople
......@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member
class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
class PersonViewSet(viewsets.ModelViewSet):
""" PersonSerializer resource. """
lookup_field = 'uuid'
......@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
""" Create a new person. """
person_data = request.data
partner = self.get_partner()
partner = request.site.partner
person_data['partner'] = partner.id
serializer = self.get_serializer(data=person_data)
serializer.is_valid(raise_exception=True)
......
......@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.get_serializer_class().prefetch_queryset()
partner = self.request.site.partner
return self.get_serializer_class().prefetch_queryset(partner)
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
......@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if get_query_param(self.request, 'uuids_only'):
# DRF serializers don't have good support for simple, flat
# representations like the one we want here.
queryset = self.filter_queryset(Program.objects.all())
queryset = self.filter_queryset(Program.objects.filter(partner=self.request.site.partner))
uuids = queryset.values_list('uuid', flat=True)
return Response(uuids)
......
......@@ -12,7 +12,6 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
......@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(PartnerMixin, APIView):
class TypeaheadSearchView(APIView):
""" Typeahead for courses and programs. """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
......@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView):
type: string
"""
query = request.query_params.get('q')
partner = self.get_partner()
partner = request.site.partner
if not query:
raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner)
......
......@@ -10,6 +10,7 @@ from django.urls import reverse
from django.utils.encoding import force_text
from course_discovery.apps.core.constants import Status
from course_discovery.apps.core.views import get_database_status
User = get_user_model()
......@@ -17,13 +18,24 @@ User = get_user_model()
class HealthTests(TestCase):
"""Tests of the health endpoint."""
def test_getting_database_ok_status(self):
"""Method should return the OK status."""
status = get_database_status()
self.assertEqual(status, Status.OK)
def test_getting_database_unavailable_status(self):
"""Method should return the unavailable status when a DatabaseError occurs."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
status = get_database_status()
self.assertEqual(status, Status.UNAVAILABLE)
def test_all_services_available(self):
"""Test that the endpoint reports when all services are healthy."""
self._assert_health(200, Status.OK, Status.OK)
def test_database_outage(self):
"""Test that the endpoint reports when the database is unavailable."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
with mock.patch('course_discovery.apps.core.views.get_database_status', return_value=Status.UNAVAILABLE):
self._assert_health(503, Status.UNAVAILABLE, Status.UNAVAILABLE)
def _assert_health(self, status_code, overall_status, database_status):
......
......@@ -15,6 +15,18 @@ logger = logging.getLogger(__name__)
User = get_user_model()
def get_database_status():
"""Run a database query to see if the database is responsive."""
try:
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
return Status.OK
except DatabaseError:
return Status.UNAVAILABLE
@transaction.non_atomic_requests
def health(_):
"""Allows a load balancer to verify this service is up.
......@@ -32,15 +44,7 @@ def health(_):
>>> response.content
'{"overall_status": "OK", "detailed_status": {"database_status": "OK"}}'
"""
try:
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
database_status = Status.OK
except DatabaseError:
database_status = Status.UNAVAILABLE
database_status = get_database_status()
overall_status = Status.OK if (database_status == Status.OK) else Status.UNAVAILABLE
......
......@@ -2,17 +2,17 @@ import datetime
import urllib.parse
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.tests.test_views.test_search import (
DefaultPartnerMixin, ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
)
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, ProgramFactory
from course_discovery.apps.edx_catalog_extensions.api.serializers import DistinctCountsAggregateFacetSearchSerializer
class DistinctCountsAggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin,
class DistinctCountsAggregateSearchViewSetTests(SerializationMixin, LoginMixin,
ElasticsearchTestMixin, SynonymTestMixin, APITestCase):
path = reverse('extensions:api:v1:search-all-facets')
......
......@@ -79,6 +79,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.contrib.sites.middleware.CurrentSiteMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'social_django.middleware.SocialAuthExceptionMiddleware',
'waffle.middleware.WaffleMiddleware',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment