Commit 695ec99a by Clinton Blackburn Committed by GitHub

Updated search endpoint to filter by default partner (#297)

If no partner is specified on the request, all results are filtered so that only items associated with the default partner are returned.

ECOM-5459
parent 657e2204
import logging
import django_filters
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db.models import QuerySet
from django.utils.translation import ugettext as _
from drf_haystack.filters import HaystackFacetFilter
from drf_haystack.filters import HaystackFacetFilter, HaystackFilter as DefaultHaystackFilter
from drf_haystack.query import FacetQueryBuilder
from dry_rest_permissions.generics import DRYPermissionFiltersBase
from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import PermissionDenied, NotFound
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
logger = logging.getLogger(__name__)
......@@ -61,6 +63,18 @@ class HaystackFacetFilterWithQueries(HaystackFacetFilter):
query_builder_class = FacetQueryBuilderWithQueries
class HaystackFilter(DefaultHaystackFilter):
@staticmethod
def get_request_filters(request):
filters = HaystackFacetFilter.get_request_filters(request)
# Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code
return filters
class CharListFilter(django_filters.CharFilter):
def filter(self, qs, value): # pylint: disable=method-hidden
if value not in (None, ''):
......
......@@ -45,7 +45,7 @@ COURSE_RUN_FACET_FIELD_QUERIES = {
COURSE_RUN_SEARCH_FIELDS = (
'text', 'key', 'title', 'short_description', 'full_description', 'start', 'end', 'enrollment_start',
'enrollment_end', 'pacing_type', 'language', 'transcript_languages', 'marketing_url', 'content_type', 'org',
'number', 'seat_types', 'image_url', 'type', 'level_type', 'availability', 'published',
'number', 'seat_types', 'image_url', 'type', 'level_type', 'availability', 'published', 'partner',
)
PROGRAM_FACET_FIELD_OPTIONS = {
......@@ -57,7 +57,7 @@ PROGRAM_FACET_FIELD_OPTIONS = {
BASE_PROGRAM_FIELDS = (
'text', 'uuid', 'title', 'subtitle', 'type', 'marketing_url', 'content_type', 'status', 'card_image_url',
'published',
'published', 'partner',
)
PROGRAM_SEARCH_FIELDS = BASE_PROGRAM_FIELDS + ('authoring_organizations',)
......
......@@ -221,7 +221,6 @@ class ProgramCourseSerializerTests(TestCase):
class ProgramSerializerTests(TestCase):
def test_data(self):
request = make_request()
org_list = OrganizationFactory.create_batch(1)
......@@ -579,7 +578,6 @@ class AffiliateWindowSerializerTests(TestCase):
class CourseRunSearchSerializerTests(TestCase):
def test_data(self):
course_run = CourseRunFactory()
serializer = self.serialize_course_run(course_run)
......@@ -607,6 +605,7 @@ class CourseRunSearchSerializerTests(TestCase):
'level_type': course_run.level_type.name,
'availability': course_run.availability,
'published': course_run.status == CourseRun.Status.Published,
'partner': course_run.course.partner.short_code,
}
self.assertDictEqual(serializer.data, expected)
......@@ -624,52 +623,41 @@ class CourseRunSearchSerializerTests(TestCase):
class ProgramSearchSerializerTests(TestCase):
def test_data(self):
program = ProgramFactory()
authoring_organization, crediting_organization = OrganizationFactory.create_batch(2)
program.authoring_organizations.add(authoring_organization)
program.credit_backing_organizations.add(crediting_organization)
program.save()
expected_authoring_organizations = [
OrganizationSerializer(authoring_organization).data,
]
# NOTE: This serializer expects SearchQuerySet results, so we run a search on the newly-created object
# to generate such a result.
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = ProgramSearchSerializer(result)
expected = {
def _create_expected_data(self, program):
return {
'uuid': str(program.uuid),
'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'marketing_url': program.marketing_url,
'authoring_organizations': expected_authoring_organizations,
'authoring_organizations': OrganizationSerializer(program.authoring_organizations, many=True).data,
'content_type': 'program',
'card_image_url': program.card_image_url,
'status': program.status,
'published': program.status == Program.Status.Active,
'partner': program.partner.short_code,
}
def test_data(self):
authoring_organization, crediting_organization = OrganizationFactory.create_batch(2)
program = ProgramFactory(authoring_organizations=[authoring_organization],
credit_backing_organizations=[crediting_organization])
# NOTE: This serializer expects SearchQuerySet results, so we run a search on the newly-created object
# to generate such a result.
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = ProgramSearchSerializer(result)
expected = self._create_expected_data(program)
self.assertDictEqual(serializer.data, expected)
def test_organization_bodies_missing(self):
program = ProgramFactory()
def test_data_without_organizations(self):
""" Verify the serializer serialized programs with no associated organizations.
In such cases the organizations value should be an empty array. """
program = ProgramFactory(authoring_organizations=[], credit_backing_organizations=[])
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
result.organization_bodies = None
serializer = ProgramSearchSerializer(result)
expected = {
'uuid': str(program.uuid),
'title': program.title,
'subtitle': program.subtitle,
'type': program.type.name,
'marketing_url': program.marketing_url,
'authoring_organizations': [],
'content_type': 'program',
'card_image_url': program.card_image_url,
'status': program.status,
'published': program.status == Program.Status.Active,
}
expected = self._create_expected_data(program)
self.assertDictEqual(serializer.data, expected)
......@@ -3,12 +3,13 @@ import json
import urllib.parse
import ddt
from django.conf import settings
from django.core.urlresolvers import reverse
from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import CourseRunSearchSerializer, ProgramSearchSerializer
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD, PartnerFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import CourseRun, Program
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory
......@@ -31,8 +32,15 @@ class LoginMixin:
self.client.login(username=self.user.username, password=USER_PASSWORD)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt
class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
""" Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets')
list_path = reverse('api:v1:search-course_runs-list')
......@@ -73,7 +81,8 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
""" Asserts the search functionality returns results for a generated query. """
# Generate data that should be indexed and returned by the query
course_run = CourseRunFactory(course__title='Software Testing', status=CourseRun.Status.Published)
course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing',
status=CourseRun.Status.Published)
response = self.get_search_response('software', faceted=faceted)
self.assertEqual(response.status_code, 200)
......@@ -109,14 +118,14 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
def test_availability_faceting(self):
""" Verify the endpoint returns availability facets with the results. """
now = datetime.datetime.utcnow()
archived = CourseRunFactory(start=now - datetime.timedelta(weeks=2), end=now - datetime.timedelta(weeks=1),
status=CourseRun.Status.Published)
current = CourseRunFactory(start=now - datetime.timedelta(weeks=2), end=now + datetime.timedelta(weeks=1),
status=CourseRun.Status.Published)
starting_soon = CourseRunFactory(start=now + datetime.timedelta(days=10), end=now + datetime.timedelta(days=90),
status=CourseRun.Status.Published)
upcoming = CourseRunFactory(start=now + datetime.timedelta(days=61), end=now + datetime.timedelta(days=90),
status=CourseRun.Status.Published)
archived = CourseRunFactory(course__partner=self.partner, start=now - datetime.timedelta(weeks=2),
end=now - datetime.timedelta(weeks=1), status=CourseRun.Status.Published)
current = CourseRunFactory(course__partner=self.partner, start=now - datetime.timedelta(weeks=2),
end=now + datetime.timedelta(weeks=1), status=CourseRun.Status.Published)
starting_soon = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=10),
end=now + datetime.timedelta(days=90), status=CourseRun.Status.Published)
upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=61),
end=now + datetime.timedelta(days=90), status=CourseRun.Status.Published)
response = self.get_search_response(faceted=True)
self.assertEqual(response.status_code, 200)
......@@ -162,14 +171,12 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
self.assertDictContainsSubset(expected, response_data['queries'])
class AggregateSearchViewSet(SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets')
def get_search_response(self, query=None):
qs = ''
if query:
qs = urllib.parse.urlencode({'q': query})
def get_search_response(self, querystring=None):
querystring = querystring or {}
qs = urllib.parse.urlencode(querystring)
url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url)
......@@ -177,14 +184,40 @@ class AggregateSearchViewSet(SerializationMixin, LoginMixin, ElasticsearchTestMi
def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """
# These items should NOT be in the results
CourseRunFactory(status=CourseRun.Status.Unpublished)
ProgramFactory(status=Program.Status.Unpublished)
CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Unpublished)
ProgramFactory(partner=self.partner, status=Program.Status.Unpublished)
course_run = CourseRunFactory(status=CourseRun.Status.Published)
program = ProgramFactory(status=Program.Status.Active)
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Published)
program = ProgramFactory(partner=self.partner, status=Program.Status.Active)
response = self.get_search_response()
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
[self.serialize_course_run(course_run), self.serialize_program(program)])
def test_results_filtered_by_default_partner(self):
""" Verify the search results only include items related to the default partner if no partner is
specified on the request. If a partner is included, the data should be filtered to the requested partner. """
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRun.Status.Published)
program = ProgramFactory(partner=self.partner, status=Program.Status.Active)
# This data should NOT be in the results
other_partner = PartnerFactory()
other_course_run = CourseRunFactory(course__partner=other_partner, status=CourseRun.Status.Published)
other_program = ProgramFactory(partner=other_partner, status=Program.Status.Active)
self.assertNotEqual(other_program.partner.short_code, self.partner.short_code)
self.assertNotEqual(other_course_run.course.partner.short_code, self.partner.short_code)
response = self.get_search_response()
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
[self.serialize_program(program), self.serialize_course_run(course_run)])
# Filter results by partner
response = self.get_search_response({'partner': other_partner.short_code})
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
[self.serialize_course_run(other_course_run), self.serialize_program(other_program)])
......@@ -12,7 +12,6 @@ from django.db.models import Q
from django.db.models.functions import Lower
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from drf_haystack.filters import HaystackFilter
from drf_haystack.mixins import FacetMixin
from drf_haystack.viewsets import HaystackViewSet
from dry_rest_permissions.generics import DRYPermissions
......@@ -455,7 +454,7 @@ class AffiliateWindowViewSet(viewsets.ViewSet):
class BaseHaystackViewSet(FacetMixin, HaystackViewSet):
document_uid_field = 'key'
facet_filter_backends = [filters.HaystackFacetFilterWithQueries, HaystackFilter]
facet_filter_backends = [filters.HaystackFacetFilterWithQueries, filters.HaystackFilter]
load_all = True
lookup_field = 'key'
permission_classes = (IsAuthenticated,)
......@@ -508,7 +507,7 @@ class BaseHaystackViewSet(FacetMixin, HaystackViewSet):
return super(BaseHaystackViewSet, self).facets(request)
def filter_facet_queryset(self, queryset):
queryset = super(BaseHaystackViewSet, self).filter_facet_queryset(queryset)
queryset = super().filter_facet_queryset(queryset)
facet_serializer_cls = self.get_facet_serializer_class()
field_queries = getattr(facet_serializer_cls.Meta, 'field_queries', {})
......
......@@ -111,9 +111,12 @@ class CourseRunIndex(BaseCourseIndex, indexes.Indexable):
seat_types = indexes.MultiValueField(model_attr='seat_types', null=True, faceted=True)
type = indexes.CharField(model_attr='type', null=True, faceted=True)
image_url = indexes.CharField(model_attr='card_image_url', null=True)
partner = indexes.CharField(model_attr='course__partner__short_code', null=True, faceted=True)
partner = indexes.CharField(null=True, faceted=True)
published = indexes.BooleanField(null=False, faceted=True)
def prepare_partner(self, obj):
return obj.course.partner.short_code
def prepare_published(self, obj):
return obj.status == CourseRun.Status.Published
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment