Commit c5069d83 by Matthew Piatetsky

Filter typeahead results by partner

parent 58b5c4c2
......@@ -17,7 +17,9 @@ from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWOR
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, OrganizationFactory
from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, ProgramFactory, OrganizationFactory
)
from course_discovery.apps.edx_haystack_extensions.models import ElasticsearchBoostConfig
......@@ -316,7 +318,8 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertEqual(response.data['objects']['results'], expected)
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
path = reverse('api:v1:search-typeahead')
function_score = {
'functions': [
......@@ -327,11 +330,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
'boost': 1.0, 'score_mode': 'sum', 'boost_mode': 'sum',
}
def get_typeahead_response(self, query=None):
def get_typeahead_response(self, query=None, partner=None):
qs = ''
query_dict = {'q': query, 'partner': partner or self.partner.short_code}
if query:
qs = urllib.parse.urlencode({'q': query})
qs = urllib.parse.urlencode(query_dict)
url = '{path}?{qs}'.format(path=self.path, qs=qs)
config = ElasticsearchBoostConfig.get_solo()
......@@ -342,8 +345,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_typeahead(self):
""" Test typeahead response. """
title = "Python"
course_run = CourseRunFactory(title=title)
program = ProgramFactory(title=title, status=ProgramStatus.Active)
course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -355,8 +358,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
RESULT_COUNT = TypeaheadSearchView.RESULT_COUNT
title = "Test"
for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i))
ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active)
CourseRunFactory(title="{}{}".format(title, i), course__partner=self.partner)
ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -367,9 +370,14 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Test typeahead response with multiple authoring organizations. """
title = "Design"
authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(title=title, authoring_organizations=authoring_organizations)
course_run = CourseRunFactory(
title=title,
authoring_organizations=authoring_organizations,
course__partner=self.partner
)
program = ProgramFactory(
title=title, authoring_organizations=authoring_organizations, status=ProgramStatus.Active
title=title, authoring_organizations=authoring_organizations,
status=ProgramStatus.Active, partner=self.partner
)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
......@@ -380,8 +388,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_partial_term_search(self):
""" Test typeahead response with partial term search. """
title = "Learn Data Science"
course_run = CourseRunFactory(title=title)
program = ProgramFactory(title=title, status=ProgramStatus.Active)
course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
query = "Data Sci"
response = self.get_typeahead_response(query)
self.assertEqual(response.status_code, 200)
......@@ -396,11 +404,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify that typeahead does not return unpublished or hidden courses
or programs that are not active. """
title = "Supply Chain"
course_run = CourseRunFactory(title=title)
CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished)
CourseRunFactory(title=title + "_hidden", hidden=True)
program = ProgramFactory(title=title, status=ProgramStatus.Active)
ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished)
course_run = CourseRunFactory(title=title, course__partner=self.partner)
CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner)
CourseRunFactory(title=title + "_hidden", hidden=True, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished, partner=self.partner)
query = "Supply"
response = self.get_typeahead_response(query)
self.assertEqual(response.status_code, 200)
......@@ -413,7 +421,7 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify the view raises an error if the 'q' query string parameter is not provided. """
response = self.get_typeahead_response()
self.assertEqual(response.status_code, 400)
self.assertDictEqual(response.data, {'detail': 'The \'q\' querystring parameter is required for searching.'})
self.assertEqual(response.data, ["The 'q' querystring parameter is required for searching."])
def test_micromasters_boosting(self):
""" Verify micromasters are boosted over xseries."""
......@@ -421,9 +429,13 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
ProgramFactory(
title=title + "1",
status=ProgramStatus.Active,
type=ProgramType.objects.get(name='MicroMasters')
type=ProgramType.objects.get(name='MicroMasters'),
partner=self.partner
)
ProgramFactory(
title=title + "2", status=ProgramStatus.Active,
type=ProgramType.objects.get(name='XSeries'), partner=self.partner
)
ProgramFactory(title=title + "2", status=ProgramStatus.Active, type=ProgramType.objects.get(name='XSeries'))
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -434,8 +446,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify upcoming courses are boosted over past courses."""
title = "start"
now = datetime.datetime.utcnow()
CourseRunFactory(title=title + "1", start=now - datetime.timedelta(weeks=10))
CourseRunFactory(title=title + "2", start=now + datetime.timedelta(weeks=1))
CourseRunFactory(title=title + "1", start=now - datetime.timedelta(weeks=10), course__partner=self.partner)
CourseRunFactory(title=title + "2", start=now + datetime.timedelta(weeks=1), course__partner=self.partner)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -444,8 +456,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_self_paced_boosting(self):
""" Verify that self paced courses are boosted over instructor led courses."""
title = "paced"
CourseRunFactory(title=title + "1", pacing_type='instructor_paced')
CourseRunFactory(title=title + "2", pacing_type='self_paced')
CourseRunFactory(title=title + "1", pacing_type='instructor_paced', course__partner=self.partner)
CourseRunFactory(title=title + "2", pacing_type='self_paced', course__partner=self.partner)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -454,8 +466,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_typeahead_authoring_organizations_partial_search(self):
""" Test typeahead response with partial organization matching. """
authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(authoring_organizations=authoring_organizations)
program = ProgramFactory(authoring_organizations=authoring_organizations)
course_run = CourseRunFactory(authoring_organizations=authoring_organizations, course__partner=self.partner)
program = ProgramFactory(authoring_organizations=authoring_organizations, partner=self.partner)
partial_key = authoring_organizations[0].key[0:5]
response = self.get_typeahead_response(partial_key)
......@@ -472,19 +484,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
HarvardX = OrganizationFactory(key='HarvardX')
mit_run = CourseRunFactory(
authoring_organizations=[MITx, HarvardX],
title='MIT Testing1'
title='MIT Testing1',
course__partner=self.partner
)
harvard_run = CourseRunFactory(
authoring_organizations=[HarvardX],
title='MIT Testing2'
title='MIT Testing2',
course__partner=self.partner
)
mit_program = ProgramFactory(
authoring_organizations=[MITx, HarvardX],
title='MIT Testing1'
title='MIT Testing1',
partner=self.partner
)
harvard_program = ProgramFactory(
authoring_organizations=[HarvardX],
title='MIT Testing2'
title='MIT Testing2',
partner=self.partner
)
response = self.get_typeahead_response('mit')
self.assertEqual(response.status_code, 200)
......@@ -495,3 +511,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
self.serialize_program(harvard_program)]
}
self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_typeahead_response('partner', 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
import logging
from django.conf import settings
from django.contrib.auth import get_user_model
from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
logger = logging.getLogger(__name__)
User = get_user_model()
......@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
from django.conf import settings
from django.db.models.functions import Lower
from rest_framework import viewsets, status
from rest_framework.decorators import list_route
......@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.core.models import Partner
from course_discovery.apps.api.v1.views import get_query_param, PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
class CourseRunViewSet(PartnerMixin, viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter
......@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
queryset = CourseRun.objects.all().order_by(Lower('key'))
serializer_class = serializers.CourseRunWithProgramsSerializer
def _get_partner(self):
""" Return the partner for the code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
def get_queryset(self):
""" List one course run
---
......@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
multiple: false
"""
q = self.request.query_params.get('q', None)
partner = self._get_partner()
partner = self.get_partner()
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
......@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner = self._get_partner()
partner = self.get_partner()
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
......
......@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery
from haystack.query import SearchQuerySet
from rest_framework import status
from rest_framework.decorators import list_route
from rest_framework.exceptions import ParseError
from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.filters import OrderingFilter
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
......@@ -13,6 +13,7 @@ from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
......@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(APIView):
class TypeaheadSearchView(PartnerMixin, APIView):
""" Typeahead for courses and programs. """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
def get_results(self, query):
def get_results(self, query, partner):
sqs = SearchQuerySet()
clean_query = sqs.query.clean(query)
......@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView):
SQ(course_key=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query)
)
course_runs = course_runs.filter(published=True).exclude(hidden=True)
course_runs = course_runs.filter(published=True).exclude(hidden=True).filter(partner=partner.short_code)
course_runs = course_runs[:self.RESULT_COUNT]
programs = sqs.models(Program).filter(
SQ(title_autocomplete=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query)
)
programs = programs.filter(status=ProgramStatus.Active)
programs = programs.filter(status=ProgramStatus.Active).filter(partner=partner.short_code)
programs = programs[:self.RESULT_COUNT]
return course_runs, programs
......@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView):
paramType: query
required: true
type: string
- name: partner
description: "Partner short code"
paramType: query
required: false
type: string
"""
query = request.query_params.get('q')
partner = self.get_partner()
if not query:
raise ParseError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query)
raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner)
data = {'course_runs': course_runs, 'programs': programs}
serializer = serializers.TypeaheadSearchSerializer(data)
return Response(serializer.data, status=status.HTTP_200_OK)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment