Commit 695ec99a by Clinton Blackburn Committed by GitHub

Updated search endpoint to filter by default partner (#297)

If no partner is specified on the request, all results are filtered so that only items associated with the default partner are returned.

ECOM-5459
parent 657e2204
import logging import logging
import django_filters import django_filters
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models import QuerySet from django.db.models import QuerySet
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from drf_haystack.filters import HaystackFacetFilter from drf_haystack.filters import HaystackFacetFilter, HaystackFilter as DefaultHaystackFilter
from drf_haystack.query import FacetQueryBuilder from drf_haystack.query import FacetQueryBuilder
from dry_rest_permissions.generics import DRYPermissionFiltersBase from dry_rest_permissions.generics import DRYPermissionFiltersBase
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import PermissionDenied, NotFound from rest_framework.exceptions import PermissionDenied, NotFound
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -61,6 +63,18 @@ class HaystackFacetFilterWithQueries(HaystackFacetFilter): ...@@ -61,6 +63,18 @@ class HaystackFacetFilterWithQueries(HaystackFacetFilter):
query_builder_class = FacetQueryBuilderWithQueries query_builder_class = FacetQueryBuilderWithQueries
class HaystackFilter(DefaultHaystackFilter):
@staticmethod
def get_request_filters(request):
filters = HaystackFacetFilter.get_request_filters(request)
# Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code
return filters
class CharListFilter(django_filters.CharFilter): class CharListFilter(django_filters.CharFilter):
def filter(self, qs, value): # pylint: disable=method-hidden def filter(self, qs, value): # pylint: disable=method-hidden
if value not in (None, ''): if value not in (None, ''):
......
...@@ -45,7 +45,7 @@ COURSE_RUN_FACET_FIELD_QUERIES = { ...@@ -45,7 +45,7 @@ COURSE_RUN_FACET_FIELD_QUERIES = {
COURSE_RUN_SEARCH_FIELDS = ( COURSE_RUN_SEARCH_FIELDS = (
'text', 'key', 'title', 'short_description', 'full_description', 'start', 'end', 'enrollment_start', 'text', 'key', 'title', 'short_description', 'full_description', 'start', 'end', 'enrollment_start',
'enrollment_end', 'pacing_type', 'language', 'transcript_languages', 'marketing_url', 'content_type', 'org', 'enrollment_end', 'pacing_type', 'language', 'transcript_languages', 'marketing_url', 'content_type', 'org',
'number', 'seat_types', 'image_url', 'type', 'level_type', 'availability', 'published', 'number', 'seat_types', 'image_url', 'type', 'level_type', 'availability', 'published', 'partner',
) )
PROGRAM_FACET_FIELD_OPTIONS = { PROGRAM_FACET_FIELD_OPTIONS = {
...@@ -57,7 +57,7 @@ PROGRAM_FACET_FIELD_OPTIONS = { ...@@ -57,7 +57,7 @@ PROGRAM_FACET_FIELD_OPTIONS = {
BASE_PROGRAM_FIELDS = ( BASE_PROGRAM_FIELDS = (
'text', 'uuid', 'title', 'subtitle', 'type', 'marketing_url', 'content_type', 'status', 'card_image_url', 'text', 'uuid', 'title', 'subtitle', 'type', 'marketing_url', 'content_type', 'status', 'card_image_url',
'published', 'published', 'partner',
) )
PROGRAM_SEARCH_FIELDS = BASE_PROGRAM_FIELDS + ('authoring_organizations',) PROGRAM_SEARCH_FIELDS = BASE_PROGRAM_FIELDS + ('authoring_organizations',)
......
...@@ -221,7 +221,6 @@ class ProgramCourseSerializerTests(TestCase): ...@@ -221,7 +221,6 @@ class ProgramCourseSerializerTests(TestCase):
class ProgramSerializerTests(TestCase): class ProgramSerializerTests(TestCase):
def test_data(self): def test_data(self):
request = make_request() request = make_request()
org_list = OrganizationFactory.create_batch(1) org_list = OrganizationFactory.create_batch(1)
...@@ -579,7 +578,6 @@ class AffiliateWindowSerializerTests(TestCase): ...@@ -579,7 +578,6 @@ class AffiliateWindowSerializerTests(TestCase):
class CourseRunSearchSerializerTests(TestCase): class CourseRunSearchSerializerTests(TestCase):
def test_data(self): def test_data(self):
course_run = CourseRunFactory() course_run = CourseRunFactory()
serializer = self.serialize_course_run(course_run) serializer = self.serialize_course_run(course_run)
...@@ -607,6 +605,7 @@ class CourseRunSearchSerializerTests(TestCase): ...@@ -607,6 +605,7 @@ class CourseRunSearchSerializerTests(TestCase):
'level_type': course_run.level_type.name, 'level_type': course_run.level_type.name,
'availability': course_run.availability, 'availability': course_run.availability,
'published': course_run.status == CourseRun.Status.Published, 'published': course_run.status == CourseRun.Status.Published,
'partner': course_run.course.partner.short_code,
} }
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
...@@ -624,52 +623,41 @@ class CourseRunSearchSerializerTests(TestCase): ...@@ -624,52 +623,41 @@ class CourseRunSearchSerializerTests(TestCase):
class ProgramSearchSerializerTests(TestCase): class ProgramSearchSerializerTests(TestCase):
def test_data(self): def _create_expected_data(self, program):
program = ProgramFactory() return {
authoring_organization, crediting_organization = OrganizationFactory.create_batch(2)
program.authoring_organizations.add(authoring_organization)
program.credit_backing_organizations.add(crediting_organization)
program.save()
expected_authoring_organizations = [
OrganizationSerializer(authoring_organization).data,
]
# NOTE: This serializer expects SearchQuerySet results, so we run a search on the newly-created object
# to generate such a result.
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = ProgramSearchSerializer(result)
expected = {
'uuid': str(program.uuid), 'uuid': str(program.uuid),
'title': program.title, 'title': program.title,
'subtitle': program.subtitle, 'subtitle': program.subtitle,
'type': program.type.name, 'type': program.type.name,
'marketing_url': program.marketing_url, 'marketing_url': program.marketing_url,
'authoring_organizations': expected_authoring_organizations, 'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'content_type': 'program', 'content_type': 'program',
'card_image_url': program.card_image_url, 'card_image_url': program.card_image_url,
'status': program.status, 'status': program.status,
'published': program.status == Program.Status.Active, 'published': program.status == Program.Status.Active,
'partner': program.partner.short_code,
} }
def test_data(self):
authoring_organization, crediting_organization = OrganizationFactory.create_batch(2)
program = ProgramFactory(authoring_organizations=[authoring_organization],
credit_backing_organizations=[crediting_organization])
# NOTE: This serializer expects SearchQuerySet results, so we run a search on the newly-created object
# to generate such a result.
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = ProgramSearchSerializer(result)
expected = self._create_expected_data(program)
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_organization_bodies_missing(self): def test_data_without_organizations(self):
program = ProgramFactory() """ Verify the serializer serialized programs with no associated organizations.
In such cases the organizations value should be an empty array. """
program = ProgramFactory(authoring_organizations=[], credit_backing_organizations=[])
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0] result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
result.organization_bodies = None
serializer = ProgramSearchSerializer(result) serializer = ProgramSearchSerializer(result)
expected = { expected = self._create_expected_data(program)
'uuid': str(program.uuid),
'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'marketing_url': program.marketing_url,
'authoring_organizations': [],
'content_type': 'program',
'card_image_url': program.card_image_url,
'status': program.status,
'published': program.status == Program.Status.Active,
}
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
...@@ -3,12 +3,13 @@ import json ...@@ -3,12 +3,13 @@ import json
import urllib.parse import urllib.parse
import ddt import ddt
from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import CourseRunSearchSerializer, ProgramSearchSerializer from course_discovery.apps.api.serializers import CourseRunSearchSerializer, ProgramSearchSerializer
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD, PartnerFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun, Program from course_discovery.apps.course_metadata.models import CourseRun, Program
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory
...@@ -31,8 +32,15 @@ class LoginMixin: ...@@ -31,8 +32,15 @@ class LoginMixin:
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt @ddt.ddt
class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase): class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
""" Tests for CourseRunSearchViewSet. """ """ Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets') faceted_path = reverse('api:v1:search-course_runs-facets')
list_path = reverse('api:v1:search-course_runs-list') list_path = reverse('api:v1:search-course_runs-list')
...@@ -73,7 +81,8 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT ...@@ -73,7 +81,8 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
""" Asserts the search functionality returns results for a generated query. """ """ Asserts the search functionality returns results for a generated query. """
# Generate data that should be indexed and returned by the query # Generate data that should be indexed and returned by the query
course_run = CourseRunFactory(course__title='Software Testing', status=CourseRun.Status.Published) course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing',
status=CourseRun.Status.Published)
response = self.get_search_response('software', faceted=faceted) response = self.get_search_response('software', faceted=faceted)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -109,14 +118,14 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT ...@@ -109,14 +118,14 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
def test_availability_faceting(self): def test_availability_faceting(self):
""" Verify the endpoint returns availability facets with the results. """ """ Verify the endpoint returns availability facets with the results. """
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
archived = CourseRunFactory(start=now - datetime.timedelta(weeks=2), end=now - datetime.timedelta(weeks=1), archived = CourseRunFactory(course__partner=self.partner, start=now - datetime.timedelta(weeks=2),
status=CourseRun.Status.Published) end=now - datetime.timedelta(weeks=1), status=CourseRun.Status.Published)
current = CourseRunFactory(start=now - datetime.timedelta(weeks=2), end=now + datetime.timedelta(weeks=1), current = CourseRunFactory(course__partner=self.partner, start=now - datetime.timedelta(weeks=2),
status=CourseRun.Status.Published) end=now + datetime.timedelta(weeks=1), status=CourseRun.Status.Published)
starting_soon = CourseRunFactory(start=now + datetime.timedelta(days=10), end=now + datetime.timedelta(days=90), starting_soon = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=10),
status=CourseRun.Status.Published) end=now + datetime.timedelta(days=90), status=CourseRun.Status.Published)
upcoming = CourseRunFactory(start=now + datetime.timedelta(days=61), end=now + datetime.timedelta(days=90), upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=61),
status=CourseRun.Status.Published) end=now + datetime.timedelta(days=90), status=CourseRun.Status.Published)
response = self.get_search_response(faceted=True) response = self.get_search_response(faceted=True)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -162,14 +171,12 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT ...@@ -162,14 +171,12 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
self.assertDictContainsSubset(expected, response_data['queries']) self.assertDictContainsSubset(expected, response_data['queries'])
class AggregateSearchViewSet(SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase): class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets') path = reverse('api:v1:search-all-facets')
def get_search_response(self, query=None): def get_search_response(self, querystring=None):
qs = '' querystring = querystring or {}
qs = urllib.parse.urlencode(querystring)
if query:
qs = urllib.parse.urlencode({'q': query})
url = '{path}?{qs}'.format(path=self.path, qs=qs) url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url) return self.client.get(url)
...@@ -177,14 +184,40 @@ class AggregateSearchViewSet(SerializationMixin, LoginMixin, ElasticsearchTestMi ...@@ -177,14 +184,40 @@ class AggregateSearchViewSet(SerializationMixin, LoginMixin, ElasticsearchTestMi
def test_results_only_include_published_objects(self): def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """ """ Verify the search results only include items with status set to 'Published'. """
# These items should NOT be in the results # These items should NOT be in the results
CourseRunFactory(status=CourseRun.Status.Unpublished) CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Unpublished)
ProgramFactory(status=Program.Status.Unpublished) ProgramFactory(partner=self.partner, status=Program.Status.Unpublished)
course_run = CourseRunFactory(status=CourseRun.Status.Published) course_run = CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Published)
program = ProgramFactory(status=Program.Status.Active) program = ProgramFactory(partner=self.partner, status=Program.Status.Active)
response = self.get_search_response() response = self.get_search_response()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'], self.assertListEqual(response_data['objects']['results'],
[self.serialize_course_run(course_run), self.serialize_program(program)]) [self.serialize_course_run(course_run), self.serialize_program(program)])
def test_results_filtered_by_default_partner(self):
""" Verify the search results only include items related to the default partner if no partner is
specified on the request. If a partner is included, the data should be filtered to the requested partner. """
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Published)
program = ProgramFactory(partner=self.partner, status=Program.Status.Active)
# This data should NOT be in the results
other_partner = PartnerFactory()
other_course_run = CourseRunFactory(course__partner=other_partner, status=CourseRun.Status.Published)
other_program = ProgramFactory(partner=other_partner, status=Program.Status.Active)
self.assertNotEqual(other_program.partner.short_code, self.partner.short_code)
self.assertNotEqual(other_course_run.course.partner.short_code, self.partner.short_code)
response = self.get_search_response()
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
[self.serialize_program(program), self.serialize_course_run(course_run)])
# Filter results by partner
response = self.get_search_response({'partner': other_partner.short_code})
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
[self.serialize_course_run(other_course_run), self.serialize_program(other_program)])
...@@ -12,7 +12,6 @@ from django.db.models import Q ...@@ -12,7 +12,6 @@ from django.db.models import Q
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.http import HttpResponse from django.http import HttpResponse
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from drf_haystack.filters import HaystackFilter
from drf_haystack.mixins import FacetMixin from drf_haystack.mixins import FacetMixin
from drf_haystack.viewsets import HaystackViewSet from drf_haystack.viewsets import HaystackViewSet
from dry_rest_permissions.generics import DRYPermissions from dry_rest_permissions.generics import DRYPermissions
...@@ -455,7 +454,7 @@ class AffiliateWindowViewSet(viewsets.ViewSet): ...@@ -455,7 +454,7 @@ class AffiliateWindowViewSet(viewsets.ViewSet):
class BaseHaystackViewSet(FacetMixin, HaystackViewSet): class BaseHaystackViewSet(FacetMixin, HaystackViewSet):
document_uid_field = 'key' document_uid_field = 'key'
facet_filter_backends = [filters.HaystackFacetFilterWithQueries, HaystackFilter] facet_filter_backends = [filters.HaystackFacetFilterWithQueries, filters.HaystackFilter]
load_all = True load_all = True
lookup_field = 'key' lookup_field = 'key'
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
...@@ -508,7 +507,7 @@ class BaseHaystackViewSet(FacetMixin, HaystackViewSet): ...@@ -508,7 +507,7 @@ class BaseHaystackViewSet(FacetMixin, HaystackViewSet):
return super(BaseHaystackViewSet, self).facets(request) return super(BaseHaystackViewSet, self).facets(request)
def filter_facet_queryset(self, queryset): def filter_facet_queryset(self, queryset):
queryset = super(BaseHaystackViewSet, self).filter_facet_queryset(queryset) queryset = super().filter_facet_queryset(queryset)
facet_serializer_cls = self.get_facet_serializer_class() facet_serializer_cls = self.get_facet_serializer_class()
field_queries = getattr(facet_serializer_cls.Meta, 'field_queries', {}) field_queries = getattr(facet_serializer_cls.Meta, 'field_queries', {})
......
...@@ -111,9 +111,12 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable): ...@@ -111,9 +111,12 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable):
seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True) seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True)
type = indexes.CharField(model_attr='type', null=True, faceted=True) type = indexes.CharField(model_attr='type', null=True, faceted=True)
image_url = indexes.CharField(model_attr='card_image_url', null=True) image_url = indexes.CharField(model_attr='card_image_url', null=True)
partner = indexes.CharField(model_attr='course__partner__short_code', null=True, faceted=True) partner = indexes.CharField(null=True, faceted=True)
published = indexes.BooleanField(null=False, faceted=True) published = indexes.BooleanField(null=False, faceted=True)
def prepare_partner(self, obj):
return obj.course.partner.short_code
def prepare_published(self, obj): def prepare_published(self, obj):
return obj.status == CourseRun.Status.Published return obj.status == CourseRun.Status.Published
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment