Commit c5069d83 by Matthew Piatetsky

Filter typeahead results by partner

parent 58b5c4c2
...@@ -17,7 +17,9 @@ from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWOR ...@@ -17,7 +17,9 @@ from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWOR
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType from course_discovery.apps.course_metadata.models import CourseRun, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, OrganizationFactory from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, ProgramFactory, OrganizationFactory
)
from course_discovery.apps.edx_haystack_extensions.models import ElasticsearchBoostConfig from course_discovery.apps.edx_haystack_extensions.models import ElasticsearchBoostConfig
...@@ -316,7 +318,8 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -316,7 +318,8 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertEqual(response.data['objects']['results'], expected) self.assertEqual(response.data['objects']['results'], expected)
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase): class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
path = reverse('api:v1:search-typeahead') path = reverse('api:v1:search-typeahead')
function_score = { function_score = {
'functions': [ 'functions': [
...@@ -327,11 +330,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -327,11 +330,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
'boost': 1.0, 'score_mode': 'sum', 'boost_mode': 'sum', 'boost': 1.0, 'score_mode': 'sum', 'boost_mode': 'sum',
} }
def get_typeahead_response(self, query=None): def get_typeahead_response(self, query=None, partner=None):
qs = '' qs = ''
query_dict = {'q': query, 'partner': partner or self.partner.short_code}
if query: if query:
qs = urllib.parse.urlencode({'q': query}) qs = urllib.parse.urlencode(query_dict)
url = '{path}?{qs}'.format(path=self.path, qs=qs) url = '{path}?{qs}'.format(path=self.path, qs=qs)
config = ElasticsearchBoostConfig.get_solo() config = ElasticsearchBoostConfig.get_solo()
...@@ -342,8 +345,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -342,8 +345,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_typeahead(self): def test_typeahead(self):
""" Test typeahead response. """ """ Test typeahead response. """
title = "Python" title = "Python"
course_run = CourseRunFactory(title=title) course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -355,8 +358,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -355,8 +358,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
RESULT_COUNT = TypeaheadSearchView.RESULT_COUNT RESULT_COUNT = TypeaheadSearchView.RESULT_COUNT
title = "Test" title = "Test"
for i in range(RESULT_COUNT + 1): for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i)) CourseRunFactory(title="{}{}".format(title, i), course__partner=self.partner)
ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active) ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -367,9 +370,14 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -367,9 +370,14 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Test typeahead response with multiple authoring organizations. """ """ Test typeahead response with multiple authoring organizations. """
title = "Design" title = "Design"
authoring_organizations = OrganizationFactory.create_batch(3) authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(title=title, authoring_organizations=authoring_organizations) course_run = CourseRunFactory(
title=title,
authoring_organizations=authoring_organizations,
course__partner=self.partner
)
program = ProgramFactory( program = ProgramFactory(
title=title, authoring_organizations=authoring_organizations, status=ProgramStatus.Active title=title, authoring_organizations=authoring_organizations,
status=ProgramStatus.Active, partner=self.partner
) )
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -380,8 +388,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -380,8 +388,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_partial_term_search(self): def test_partial_term_search(self):
""" Test typeahead response with partial term search. """ """ Test typeahead response with partial term search. """
title = "Learn Data Science" title = "Learn Data Science"
course_run = CourseRunFactory(title=title) course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
query = "Data Sci" query = "Data Sci"
response = self.get_typeahead_response(query) response = self.get_typeahead_response(query)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -396,11 +404,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -396,11 +404,11 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify that typeahead does not return unpublished or hidden courses """ Verify that typeahead does not return unpublished or hidden courses
or programs that are not active. """ or programs that are not active. """
title = "Supply Chain" title = "Supply Chain"
course_run = CourseRunFactory(title=title) course_run = CourseRunFactory(title=title, course__partner=self.partner)
CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished) CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner)
CourseRunFactory(title=title + "_hidden", hidden=True) CourseRunFactory(title=title + "_hidden", hidden=True, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished) ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished, partner=self.partner)
query = "Supply" query = "Supply"
response = self.get_typeahead_response(query) response = self.get_typeahead_response(query)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -413,7 +421,7 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -413,7 +421,7 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify the view raises an error if the 'q' query string parameter is not provided. """ """ Verify the view raises an error if the 'q' query string parameter is not provided. """
response = self.get_typeahead_response() response = self.get_typeahead_response()
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertDictEqual(response.data, {'detail': 'The \'q\' querystring parameter is required for searching.'}) self.assertEqual(response.data, ["The 'q' querystring parameter is required for searching."])
def test_micromasters_boosting(self): def test_micromasters_boosting(self):
""" Verify micromasters are boosted over xseries.""" """ Verify micromasters are boosted over xseries."""
...@@ -421,9 +429,13 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -421,9 +429,13 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
ProgramFactory( ProgramFactory(
title=title + "1", title=title + "1",
status=ProgramStatus.Active, status=ProgramStatus.Active,
type=ProgramType.objects.get(name='MicroMasters') type=ProgramType.objects.get(name='MicroMasters'),
partner=self.partner
)
ProgramFactory(
title=title + "2", status=ProgramStatus.Active,
type=ProgramType.objects.get(name='XSeries'), partner=self.partner
) )
ProgramFactory(title=title + "2", status=ProgramStatus.Active, type=ProgramType.objects.get(name='XSeries'))
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -434,8 +446,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -434,8 +446,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
""" Verify upcoming courses are boosted over past courses.""" """ Verify upcoming courses are boosted over past courses."""
title = "start" title = "start"
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
CourseRunFactory(title=title + "1", start=now - datetime.timedelta(weeks=10)) CourseRunFactory(title=title + "1", start=now - datetime.timedelta(weeks=10), course__partner=self.partner)
CourseRunFactory(title=title + "2", start=now + datetime.timedelta(weeks=1)) CourseRunFactory(title=title + "2", start=now + datetime.timedelta(weeks=1), course__partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -444,8 +456,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -444,8 +456,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_self_paced_boosting(self): def test_self_paced_boosting(self):
""" Verify that self paced courses are boosted over instructor led courses.""" """ Verify that self paced courses are boosted over instructor led courses."""
title = "paced" title = "paced"
CourseRunFactory(title=title + "1", pacing_type='instructor_paced') CourseRunFactory(title=title + "1", pacing_type='instructor_paced', course__partner=self.partner)
CourseRunFactory(title=title + "2", pacing_type='self_paced') CourseRunFactory(title=title + "2", pacing_type='self_paced', course__partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -454,8 +466,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -454,8 +466,8 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
def test_typeahead_authoring_organizations_partial_search(self): def test_typeahead_authoring_organizations_partial_search(self):
""" Test typeahead response with partial organization matching. """ """ Test typeahead response with partial organization matching. """
authoring_organizations = OrganizationFactory.create_batch(3) authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(authoring_organizations=authoring_organizations) course_run = CourseRunFactory(authoring_organizations=authoring_organizations, course__partner=self.partner)
program = ProgramFactory(authoring_organizations=authoring_organizations) program = ProgramFactory(authoring_organizations=authoring_organizations, partner=self.partner)
partial_key = authoring_organizations[0].key[0:5] partial_key = authoring_organizations[0].key[0:5]
response = self.get_typeahead_response(partial_key) response = self.get_typeahead_response(partial_key)
...@@ -472,19 +484,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -472,19 +484,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
HarvardX = OrganizationFactory(key='HarvardX') HarvardX = OrganizationFactory(key='HarvardX')
mit_run = CourseRunFactory( mit_run = CourseRunFactory(
authoring_organizations=[MITx, HarvardX], authoring_organizations=[MITx, HarvardX],
title='MIT Testing1' title='MIT Testing1',
course__partner=self.partner
) )
harvard_run = CourseRunFactory( harvard_run = CourseRunFactory(
authoring_organizations=[HarvardX], authoring_organizations=[HarvardX],
title='MIT Testing2' title='MIT Testing2',
course__partner=self.partner
) )
mit_program = ProgramFactory( mit_program = ProgramFactory(
authoring_organizations=[MITx, HarvardX], authoring_organizations=[MITx, HarvardX],
title='MIT Testing1' title='MIT Testing1',
partner=self.partner
) )
harvard_program = ProgramFactory( harvard_program = ProgramFactory(
authoring_organizations=[HarvardX], authoring_organizations=[HarvardX],
title='MIT Testing2' title='MIT Testing2',
partner=self.partner
) )
response = self.get_typeahead_response('mit') response = self.get_typeahead_response('mit')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -495,3 +511,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics ...@@ -495,3 +511,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
self.serialize_program(harvard_program)] self.serialize_program(harvard_program)]
} }
self.assertDictEqual(response.data, expected) self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_typeahead_response('partner', 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
import logging import logging
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.utils import cast2int from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
...@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course']) queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course']) queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
from django.conf import settings
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework import viewsets, status from rest_framework import viewsets, status
from rest_framework.decorators import list_route from rest_framework.decorators import list_route
...@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated ...@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError from course_discovery.apps.api.v1.views import get_query_param, PartnerMixin
from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member # pylint: disable=no-member
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): class CourseRunViewSet(PartnerMixin, viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """ """ CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter filter_class = filters.CourseRunFilter
...@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
queryset = CourseRun.objects.all().order_by(Lower('key')) queryset = CourseRun.objects.all().order_by(Lower('key'))
serializer_class = serializers.CourseRunWithProgramsSerializer serializer_class = serializers.CourseRunWithProgramsSerializer
def _get_partner(self):
""" Return the partner for the code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
def get_queryset(self): def get_queryset(self):
""" List one course run """ List one course run
--- ---
...@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
multiple: false multiple: false
""" """
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q', None)
partner = self._get_partner() partner = self.get_partner()
if q: if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code)) qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
...@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
""" """
query = request.GET.get('query') query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids') course_run_ids = request.GET.get('course_run_ids')
partner = self._get_partner() partner = self.get_partner()
if query and course_run_ids: if query and course_run_ids:
course_run_ids = course_run_ids.split(',') course_run_ids = course_run_ids.split(',')
......
...@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery ...@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from rest_framework import status from rest_framework import status
from rest_framework.decorators import list_route from rest_framework.decorators import list_route
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.filters import OrderingFilter from rest_framework.filters import OrderingFilter
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
...@@ -13,6 +13,7 @@ from rest_framework.views import APIView ...@@ -13,6 +13,7 @@ from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import PageNumberPagination from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
...@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet): ...@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(APIView): class TypeaheadSearchView(PartnerMixin, APIView):
""" Typeahead for courses and programs. """ """ Typeahead for courses and programs. """
RESULT_COUNT = 3 RESULT_COUNT = 3
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
def get_results(self, query): def get_results(self, query, partner):
sqs = SearchQuerySet() sqs = SearchQuerySet()
clean_query = sqs.query.clean(query) clean_query = sqs.query.clean(query)
...@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView): ...@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView):
SQ(course_key=clean_query) | SQ(course_key=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query) SQ(authoring_organizations_autocomplete=clean_query)
) )
course_runs = course_runs.filter(published=True).exclude(hidden=True) course_runs = course_runs.filter(published=True).exclude(hidden=True).filter(partner=partner.short_code)
course_runs = course_runs[:self.RESULT_COUNT] course_runs = course_runs[:self.RESULT_COUNT]
programs = sqs.models(Program).filter( programs = sqs.models(Program).filter(
SQ(title_autocomplete=clean_query) | SQ(title_autocomplete=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query) SQ(authoring_organizations_autocomplete=clean_query)
) )
programs = programs.filter(status=ProgramStatus.Active) programs = programs.filter(status=ProgramStatus.Active).filter(partner=partner.short_code)
programs = programs[:self.RESULT_COUNT] programs = programs[:self.RESULT_COUNT]
return course_runs, programs return course_runs, programs
...@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView): ...@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView):
paramType: query paramType: query
required: true required: true
type: string type: string
- name: partner
description: "Partner short code"
paramType: query
required: false
type: string
""" """
query = request.query_params.get('q') query = request.query_params.get('q')
partner = self.get_partner()
if not query: if not query:
raise ParseError("The 'q' querystring parameter is required for searching.") raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query) course_runs, programs = self.get_results(query, partner)
data = {'course_runs': course_runs, 'programs': programs} data = {'course_runs': course_runs, 'programs': programs}
serializer = serializers.TypeaheadSearchSerializer(data) serializer = serializers.TypeaheadSearchSerializer(data)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment