Commit a9a97c5e by Clinton Blackburn Committed by GitHub

Merge pull request #168 from edx/clintonb/api-updates

Added support for filtering courses and course runs by key
parents 4c8415e8 3f4401d5
import django_filters
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models import QuerySet
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from drf_haystack.filters import HaystackFacetFilter from drf_haystack.filters import HaystackFacetFilter
from drf_haystack.query import FacetQueryBuilder from drf_haystack.query import FacetQueryBuilder
...@@ -6,6 +8,8 @@ from dry_rest_permissions.generics import DRYPermissionFiltersBase ...@@ -6,6 +8,8 @@ from dry_rest_permissions.generics import DRYPermissionFiltersBase
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import PermissionDenied, NotFound from rest_framework.exceptions import PermissionDenied, NotFound
from course_discovery.apps.course_metadata.models import Course, CourseRun
User = get_user_model() User = get_user_model()
...@@ -43,8 +47,8 @@ class PermissionsFilter(DRYPermissionFiltersBase): ...@@ -43,8 +47,8 @@ class PermissionsFilter(DRYPermissionFiltersBase):
class FacetQueryBuilderWithQueries(FacetQueryBuilder): class FacetQueryBuilderWithQueries(FacetQueryBuilder):
def build_query(self, **filters): def build_query(self, **query_filters):
query = super(FacetQueryBuilderWithQueries, self).build_query(**filters) query = super(FacetQueryBuilderWithQueries, self).build_query(**query_filters)
facet_serializer_cls = self.view.get_facet_serializer_class() facet_serializer_cls = self.view.get_facet_serializer_class()
query['query_facets'] = getattr(facet_serializer_cls.Meta, 'field_queries', {}) query['query_facets'] = getattr(facet_serializer_cls.Meta, 'field_queries', {})
return query return query
...@@ -52,3 +56,36 @@ class FacetQueryBuilderWithQueries(FacetQueryBuilder): ...@@ -52,3 +56,36 @@ class FacetQueryBuilderWithQueries(FacetQueryBuilder):
class HaystackFacetFilterWithQueries(HaystackFacetFilter): class HaystackFacetFilterWithQueries(HaystackFacetFilter):
query_builder_class = FacetQueryBuilderWithQueries query_builder_class = FacetQueryBuilderWithQueries
class CharListFilter(django_filters.CharFilter):
def filter(self, qs, value): # pylint: disable=method-hidden
if value not in (None, ''):
value = value.split(',')
return super(CharListFilter, self).filter(qs, value)
class CourseFilter(django_filters.FilterSet):
keys = CharListFilter(name='key', lookup_type='in')
class Meta:
model = Course
fields = ['keys']
class CourseRunFilter(django_filters.FilterSet):
keys = CharListFilter(name='key', lookup_type='in')
@property
def qs(self):
# This endpoint supports query via Haystack. If that form of filtering is active,
# do not attempt to treat the queryset as a normal Django queryset.
if not isinstance(self.queryset, QuerySet):
return self.queryset
return super(CourseRunFilter, self).qs
class Meta:
model = CourseRun
fields = ['keys']
# pylint: disable=no-member # pylint: disable=no-member
import urllib import urllib
import ddt
import ddt
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase, APIRequestFactory from rest_framework.test import APITestCase, APIRequestFactory
...@@ -9,8 +9,8 @@ from rest_framework.test import APITestCase, APIRequestFactory ...@@ -9,8 +9,8 @@ from rest_framework.test import APITestCase, APIRequestFactory
from course_discovery.apps.api.serializers import CourseRunSerializer from course_discovery.apps.api.serializers import CourseRunSerializer
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
@ddt.ddt @ddt.ddt
...@@ -25,13 +25,16 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -25,13 +25,16 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
self.request = APIRequestFactory().get('/') self.request = APIRequestFactory().get('/')
self.request.user = self.user self.request.user = self.user
def serialize_course_run(self, course_run, **kwargs):
return CourseRunSerializer(course_run, context={'request': self.request}, **kwargs).data
def test_get(self): def test_get(self):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key}) url = reverse('api:v1:course_run-detail', kwargs={'key': self.course_run.key})
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, CourseRunSerializer(self.course_run, context={'request': self.request}).data) self.assertEqual(response.data, self.serialize_course_run(self.course_run))
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all catalogs. """ """ Verify the endpoint returns a list of all catalogs. """
...@@ -41,11 +44,7 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -41,11 +44,7 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
response.data['results'], response.data['results'],
CourseRunSerializer( self.serialize_course_run(CourseRun.objects.all().order_by(Lower('key')), many=True)
CourseRun.objects.all().order_by(Lower('key')),
many=True,
context={'request': self.request}
).data
) )
def test_list_query(self): def test_list_query(self):
...@@ -58,16 +57,20 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -58,16 +57,20 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
response = self.client.get(url) response = self.client.get(url)
actual_sorted = sorted(response.data['results'], key=lambda course_run: course_run['key']) actual_sorted = sorted(response.data['results'], key=lambda course_run: course_run['key'])
expected_sorted = sorted( expected_sorted = sorted(self.serialize_course_run(course_runs, many=True),
CourseRunSerializer( key=lambda course_run: course_run['key'])
course_runs,
many=True,
context={'request': self.request}
).data,
key=lambda course_run: course_run['key']
)
self.assertListEqual(actual_sorted, expected_sorted) self.assertListEqual(actual_sorted, expected_sorted)
def test_list_key_filter(self):
""" Verify the endpoint returns a list of course runs filtered by the specified keys. """
course_runs = CourseRunFactory.create_batch(3)
course_runs = sorted(course_runs, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in course_runs])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course_run-list'), keys=keys)
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course_run(course_runs, many=True))
def test_contains_single_course_run(self): def test_contains_single_course_run(self):
qs = urllib.parse.urlencode({ qs = urllib.parse.urlencode({
'query': 'id:course*', 'query': 'id:course*',
......
...@@ -44,3 +44,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -44,3 +44,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses = CourseFactory.create_batch(3)
courses = sorted(courses, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -19,11 +19,12 @@ from edx_rest_framework_extensions.permissions import IsSuperuser ...@@ -19,11 +19,12 @@ from edx_rest_framework_extensions.permissions import IsSuperuser
from rest_framework import status, viewsets from rest_framework import status, viewsets
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.exceptions import PermissionDenied, ParseError from rest_framework.exceptions import PermissionDenied, ParseError
from rest_framework.filters import DjangoFilterBackend
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from course_discovery.apps.api import filters
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
from course_discovery.apps.api.filters import PermissionsFilter, HaystackFacetFilterWithQueries
from course_discovery.apps.api.pagination import PageNumberPagination from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer, CourseRunCSVRenderer from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer, CourseRunCSVRenderer
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
...@@ -39,7 +40,7 @@ User = get_user_model() ...@@ -39,7 +40,7 @@ User = get_user_model()
class CatalogViewSet(viewsets.ModelViewSet): class CatalogViewSet(viewsets.ModelViewSet):
""" Catalog resource. """ """ Catalog resource. """
filter_backends = (PermissionsFilter,) filter_backends = (filters.PermissionsFilter,)
lookup_field = 'id' lookup_field = 'id'
permission_classes = (DRYPermissions,) permission_classes = (DRYPermissions,)
queryset = Catalog.objects.all() queryset = Catalog.objects.all()
...@@ -172,6 +173,8 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -172,6 +173,8 @@ class CatalogViewSet(viewsets.ModelViewSet):
class CourseViewSet(viewsets.ReadOnlyModelViewSet): class CourseViewSet(viewsets.ReadOnlyModelViewSet):
""" Course resource. """ """ Course resource. """
filter_backends = (DjangoFilterBackend,)
filter_class = filters.CourseFilter
lookup_field = 'key' lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all() queryset = Course.objects.all()
...@@ -190,10 +193,16 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -190,10 +193,16 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
""" List all courses. """ List all courses.
--- ---
parameters: parameters:
- name: q - name: q
description: Elasticsearch querystring query description: Elasticsearch querystring query. This filter takes precedence over other filters.
required: false
type: string
paramType: query
multiple: false
- name: keys
description: Filter by keys (comma-separated list)
required: false required: false
type: string type: string
paramType: query paramType: query
...@@ -208,6 +217,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -208,6 +217,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """ """ CourseRun resource. """
filter_backends = (DjangoFilterBackend,)
filter_class = filters.CourseRunFilter
lookup_field = 'key' lookup_field = 'key'
lookup_value_regex = COURSE_RUN_ID_REGEX lookup_value_regex = COURSE_RUN_ID_REGEX
queryset = CourseRun.objects.all().order_by(Lower('key')) queryset = CourseRun.objects.all().order_by(Lower('key'))
...@@ -217,7 +228,10 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -217,7 +228,10 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
def get_queryset(self): def get_queryset(self):
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q', None)
if q: if q:
return SearchQuerySetWrapper(CourseRun.search(q)) qs = SearchQuerySetWrapper(CourseRun.search(q))
# This is necessary to avoid issues with the filter backend.
qs.model = self.queryset.model
return qs
else: else:
return super(CourseRunViewSet, self).get_queryset() return super(CourseRunViewSet, self).get_queryset()
...@@ -226,7 +240,13 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -226,7 +240,13 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
--- ---
parameters: parameters:
- name: q - name: q
description: Elasticsearch querystring query description: Elasticsearch querystring query. This filter takes precedence over other filters.
required: false
type: string
paramType: query
multiple: false
- name: keys
description: Filter by keys (comma-separated list)
required: false required: false
type: string type: string
paramType: query paramType: query
...@@ -371,7 +391,7 @@ class AffiliateWindowViewSet(viewsets.ViewSet): ...@@ -371,7 +391,7 @@ class AffiliateWindowViewSet(viewsets.ViewSet):
class BaseHaystackViewSet(FacetMixin, HaystackViewSet): class BaseHaystackViewSet(FacetMixin, HaystackViewSet):
document_uid_field = 'key' document_uid_field = 'key'
facet_filter_backends = [HaystackFacetFilterWithQueries, HaystackFilter] facet_filter_backends = [filters.HaystackFacetFilterWithQueries, HaystackFilter]
load_all = True load_all = True
lookup_field = 'key' lookup_field = 'key'
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
......
...@@ -41,6 +41,7 @@ THIRD_PARTY_APPS = ( ...@@ -41,6 +41,7 @@ THIRD_PARTY_APPS = (
'guardian', 'guardian',
'dry_rest_permissions', 'dry_rest_permissions',
'compressor', 'compressor',
'django_filters',
) )
PROJECT_APPS = ( PROJECT_APPS = (
......
cryptography==1.4 cryptography==1.4
django==1.8.14 django==1.8.14
django-extensions==1.6.7 django-extensions==1.6.7
django-filter==0.13.0
django-guardian==1.4.4 django-guardian==1.4.4
django-haystack==2.4.1 django-haystack==2.4.1
django-libsass==0.7 django-libsass==0.7
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment