Commit 2e4b3726 by Michael Frey Committed by GitHub

Merge pull request #199 from edx/mjfrey/SOL-1924

Add partner filtering for course runs
parents a36feacf 48644be4
from rest_framework import status
from rest_framework.exceptions import APIException
class InvalidPartnerError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
......@@ -10,7 +10,7 @@ from course_discovery.apps.api.serializers import CourseRunSerializer
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, PartnerFactory
@ddt.ddt
......@@ -19,8 +19,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
self.course_run = CourseRunFactory()
self.course_run_2 = CourseRunFactory()
self.default_partner = PartnerFactory()
self.course_run = CourseRunFactory(course__partner=self.default_partner)
self.course_run_2 = CourseRunFactory(course__partner=self.default_partner)
self.refresh_index()
self.request = APIRequestFactory().get('/')
self.request.user = self.user
......@@ -49,10 +50,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """
title = 'Some random title'
course_runs = CourseRunFactory.create_batch(3, title=title)
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.default_partner)
CourseRunFactory(title='non-matching name')
query = 'title:' + title
query = 'title:Some random title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
response = self.client.get(url)
......@@ -61,6 +61,15 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def test_list_key_filter(self):
""" Verify the endpoint returns a list of course runs filtered by the specified keys. """
course_runs = CourseRunFactory.create_batch(3)
......@@ -72,9 +81,10 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
self.assertListEqual(response.data['results'], self.serialize_course_run(course_runs, many=True))
def test_contains_single_course_run(self):
""" Verify that a single course_run is contained in a query """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key
'course_run_ids': self.course_run.key,
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
......@@ -89,6 +99,18 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
}
)
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({
'query': 'id:course*',
......
......@@ -4,6 +4,7 @@ import os
from io import StringIO
import pytz
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.db import transaction
......@@ -25,12 +26,13 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters
from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer, CourseRunCSVRenderer
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat, Program
from course_discovery.apps.course_metadata.models import Course, CourseRun, Partner, Program, Seat
logger = logging.getLogger(__name__)
User = get_user_model()
......@@ -225,10 +227,27 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseRunSerializer
def _get_partner_name_from_code(self):
""" Return the partner name associated with a partner code or the default partner """
partner = None
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner.name
def get_queryset(self):
q = self.request.query_params.get('q', None)
partner_name = self._get_partner_name_from_code()
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q))
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner_name))
# This is necessary to avoid issues with the filter backend.
qs.model = self.queryset.model
return qs
......@@ -251,6 +270,12 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
type: string
paramType: query
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
"""
return super(CourseRunViewSet, self).list(request, *args, **kwargs)
......@@ -280,13 +305,21 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
type: string
paramType: query
multiple: true
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner_name = self._get_partner_name_from_code()
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
course_runs = CourseRun.search(query).filter(key__in=course_run_ids).values_list('key', flat=True)
course_runs = CourseRun.search(query).filter(partner=partner_name).filter(key__in=course_run_ids).\
values_list('key', flat=True)
contains = {course_run_id: course_run_id in course_runs for course_run_id in course_run_ids}
instance = {'course_runs': contains}
......
......@@ -40,6 +40,7 @@ class BaseCourseIndex(OrganizationsMixin, BaseIndex):
subjects = indexes.MultiValueField(faceted=True)
organizations = indexes.MultiValueField(faceted=True)
level_type = indexes.CharField(model_attr='level_type__name', null=True, faceted=True)
partner = indexes.CharField(model_attr='partner__name', null=True, faceted=True)
def prepare_subjects(self, obj):
return [subject.name for subject in obj.subjects.all()]
......@@ -83,6 +84,7 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable):
seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True)
type = indexes.CharField(model_attr='type', null=True, faceted=True)
image_url = indexes.CharField(model_attr='image_url', null=True)
partner = indexes.CharField(model_attr='course__partner__name', null=True, faceted=True)
def _prepare_language(self, language):
return language.macrolanguage
......@@ -115,6 +117,7 @@ class ProgramIndex(OrganizationsMixin, BaseIndex, indexes.Indexable):
organizations = indexes.MultiValueField(faceted=True)
image_url = indexes.CharField(model_attr='image_url', null=True)
status = indexes.CharField(model_attr='status', faceted=True)
partner = indexes.CharField(model_attr='partner__name', null=True, faceted=True)
def prepare_marketing_url(self, obj):
return obj.marketing_url
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment