Commit 2e4b3726 by Michael Frey Committed by GitHub

Merge pull request #199 from edx/mjfrey/SOL-1924

Add partner filtering for course runs
parents a36feacf 48644be4
from rest_framework import status
from rest_framework.exceptions import APIException
class InvalidPartnerError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
...@@ -10,7 +10,7 @@ from course_discovery.apps.api.serializers import CourseRunSerializer ...@@ -10,7 +10,7 @@ from course_discovery.apps.api.serializers import CourseRunSerializer
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, PartnerFactory
@ddt.ddt @ddt.ddt
...@@ -19,8 +19,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -19,8 +19,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
super(CourseRunViewSetTests, self).setUp() super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
self.course_run = CourseRunFactory() self.default_partner = PartnerFactory()
self.course_run_2 = CourseRunFactory() self.course_run = CourseRunFactory(course__partner=self.default_partner)
self.course_run_2 = CourseRunFactory(course__partner=self.default_partner)
self.refresh_index() self.refresh_index()
self.request = APIRequestFactory().get('/') self.request = APIRequestFactory().get('/')
self.request.user = self.user self.request.user = self.user
...@@ -49,10 +50,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -49,10 +50,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
def test_list_query(self): def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """ """ Verify the endpoint returns a filtered list of courses """
title = 'Some random title' course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.default_partner)
course_runs = CourseRunFactory.create_batch(3, title=title)
CourseRunFactory(title='non-matching name') CourseRunFactory(title='non-matching name')
query = 'title:' + title query = 'title:Some random title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
response = self.client.get(url) response = self.client.get(url)
...@@ -61,6 +61,15 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -61,6 +61,15 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
key=lambda course_run: course_run['key']) key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted) self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def test_list_key_filter(self): def test_list_key_filter(self):
""" Verify the endpoint returns a list of course runs filtered by the specified keys. """ """ Verify the endpoint returns a list of course runs filtered by the specified keys. """
course_runs = CourseRunFactory.create_batch(3) course_runs = CourseRunFactory.create_batch(3)
...@@ -72,9 +81,10 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -72,9 +81,10 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
self.assertListEqual(response.data['results'], self.serialize_course_run(course_runs, many=True)) self.assertListEqual(response.data['results'], self.serialize_course_run(course_runs, many=True))
def test_contains_single_course_run(self): def test_contains_single_course_run(self):
""" Verify that a single course_run is contained in a query """
qs = urllib.parse.urlencode({ qs = urllib.parse.urlencode({
'query': 'id:course*', 'query': 'id:course*',
'course_run_ids': self.course_run.key 'course_run_ids': self.course_run.key,
}) })
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs) url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
...@@ -89,6 +99,18 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -89,6 +99,18 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
} }
) )
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def test_contains_multiple_course_runs(self): def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({ qs = urllib.parse.urlencode({
'query': 'id:course*', 'query': 'id:course*',
......
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
from io import StringIO from io import StringIO
import pytz import pytz
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.management import call_command from django.core.management import call_command
from django.db import transaction from django.db import transaction
...@@ -25,12 +26,13 @@ from rest_framework.response import Response ...@@ -25,12 +26,13 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters from course_discovery.apps.api import filters
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.pagination import PageNumberPagination from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer, CourseRunCSVRenderer from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer, CourseRunCSVRenderer
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Partner, Program, Seat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
...@@ -225,10 +227,27 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -225,10 +227,27 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseRunSerializer serializer_class = serializers.CourseRunSerializer
def _get_partner_name_from_code(self):
""" Return the partner name associated with a partner code or the default partner """
partner = None
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner.name
def get_queryset(self): def get_queryset(self):
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q', None)
partner_name = self._get_partner_name_from_code()
if q: if q:
qs = SearchQuerySetWrapper(CourseRun.search(q)) qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner_name))
# This is necessary to avoid issues with the filter backend. # This is necessary to avoid issues with the filter backend.
qs.model = self.queryset.model qs.model = self.queryset.model
return qs return qs
...@@ -251,6 +270,12 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -251,6 +270,12 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
type: string type: string
paramType: query paramType: query
multiple: false multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
""" """
return super(CourseRunViewSet, self).list(request, *args, **kwargs) return super(CourseRunViewSet, self).list(request, *args, **kwargs)
...@@ -280,13 +305,21 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -280,13 +305,21 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
type: string type: string
paramType: query paramType: query
multiple: true multiple: true
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
""" """
query = request.GET.get('query') query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids') course_run_ids = request.GET.get('course_run_ids')
partner_name = self._get_partner_name_from_code()
if query and course_run_ids: if query and course_run_ids:
course_run_ids = course_run_ids.split(',') course_run_ids = course_run_ids.split(',')
course_runs = CourseRun.search(query).filter(key__in=course_run_ids).values_list('key', flat=True) course_runs = CourseRun.search(query).filter(partner=partner_name).filter(key__in=course_run_ids).\
values_list('key', flat=True)
contains = {course_run_id: course_run_id in course_runs for course_run_id in course_run_ids} contains = {course_run_id: course_run_id in course_runs for course_run_id in course_run_ids}
instance = {'course_runs': contains} instance = {'course_runs': contains}
......
...@@ -40,6 +40,7 @@ class BaseCourseIndex(OrganizationsMixin, BaseIndex): ...@@ -40,6 +40,7 @@ class BaseCourseIndex(OrganizationsMixin, BaseIndex):
subjects = indexes.MultiValueField(faceted=True) subjects = indexes.MultiValueField(faceted=True)
organizations = indexes.MultiValueField(faceted=True) organizations = indexes.MultiValueField(faceted=True)
level_type = indexes.CharField(model_attr='level_type__name', null=True, faceted=True) level_type = indexes.CharField(model_attr='level_type__name', null=True, faceted=True)
partner = indexes.CharField(model_attr='partner__name', null=True, faceted=True)
def prepare_subjects(self, obj): def prepare_subjects(self, obj):
return [subject.name for subject in obj.subjects.all()] return [subject.name for subject in obj.subjects.all()]
...@@ -83,6 +84,7 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable): ...@@ -83,6 +84,7 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable):
seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True) seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True)
type = indexes.CharField(model_attr='type', null=True, faceted=True) type = indexes.CharField(model_attr='type', null=True, faceted=True)
image_url = indexes.CharField(model_attr='image_url', null=True) image_url = indexes.CharField(model_attr='image_url', null=True)
partner = indexes.CharField(model_attr='course__partner__name', null=True, faceted=True)
def _prepare_language(self, language): def _prepare_language(self, language):
return language.macrolanguage return language.macrolanguage
...@@ -115,6 +117,7 @@ class ProgramIndex(OrganizationsMixin, BaseIndex, indexes.Indexable): ...@@ -115,6 +117,7 @@ class ProgramIndex(OrganizationsMixin, BaseIndex, indexes.Indexable):
organizations = indexes.MultiValueField(faceted=True) organizations = indexes.MultiValueField(faceted=True)
image_url = indexes.CharField(model_attr='image_url', null=True) image_url = indexes.CharField(model_attr='image_url', null=True)
status = indexes.CharField(model_attr='status', faceted=True) status = indexes.CharField(model_attr='status', faceted=True)
partner = indexes.CharField(model_attr='partner__name', null=True, faceted=True)
def prepare_marketing_url(self, obj): def prepare_marketing_url(self, obj):
return obj.marketing_url return obj.marketing_url
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment