Commit 06eb2bc4 by Matthew Piatetsky

Add partial term matching to typeahead

ECOM-4738
parent 9e2663c7
...@@ -913,20 +913,6 @@ class CourseRunSearchSerializer(HaystackSerializer): ...@@ -913,20 +913,6 @@ class CourseRunSearchSerializer(HaystackSerializer):
index_classes = [CourseRunIndex] index_classes = [CourseRunIndex]
class TypeaheadCourseRunSearchSerializer(HaystackSerializer):
    """Typeahead serializer for course run search results.

    Exposes the fields listed in ``Meta`` plus a computed ``additional_details`` value.
    """
    additional_details = serializers.SerializerMethodField()

    def get_additional_details(self, result):
        """ Return the result's ``org``, shown as grey text next to the typeahead result title. """
        return result.org

    class Meta:
        index_classes = [CourseRunIndex]
        fields = ['key', 'title', 'content_type']
        field_aliases = COMMON_SEARCH_FIELD_ALIASES
        ignore_fields = COMMON_IGNORED_FIELDS
class CourseRunFacetSerializer(BaseHaystackFacetSerializer): class CourseRunFacetSerializer(BaseHaystackFacetSerializer):
serialize_objects = True serialize_objects = True
...@@ -952,21 +938,6 @@ class ProgramSearchSerializer(HaystackSerializer): ...@@ -952,21 +938,6 @@ class ProgramSearchSerializer(HaystackSerializer):
index_classes = [ProgramIndex] index_classes = [ProgramIndex]
class TypeaheadProgramSearchSerializer(HaystackSerializer):
    """Typeahead serializer for program search results.

    Exposes the fields listed in ``Meta`` plus a computed ``additional_details`` value.
    """
    additional_details = serializers.SerializerMethodField()

    def get_additional_details(self, result):
        """ Return the grey text shown next to the typeahead result title.

        Every entry of ``result.authoring_organization_bodies`` is decoded as JSON first;
        the ``key`` of each decoded body is then joined with ``', '``.
        """
        # Decode every body before reading any key, matching the original evaluation order.
        decoded_bodies = list(map(json.loads, result.authoring_organization_bodies))
        return ', '.join(body['key'] for body in decoded_bodies)

    class Meta:
        index_classes = [ProgramIndex]
        fields = ['uuid', 'title', 'content_type', 'type']
        field_aliases = COMMON_SEARCH_FIELD_ALIASES
        ignore_fields = COMMON_IGNORED_FIELDS
class ProgramFacetSerializer(BaseHaystackFacetSerializer): class ProgramFacetSerializer(BaseHaystackFacetSerializer):
serialize_objects = True serialize_objects = True
...@@ -990,15 +961,32 @@ class AggregateSearchSerializer(HaystackSerializer): ...@@ -990,15 +961,32 @@ class AggregateSearchSerializer(HaystackSerializer):
} }
class TypeaheadSearchSerializer(HaystackSerializer): class TypeaheadCourseRunSearchSerializer(serializers.Serializer):
org = serializers.CharField()
title = serializers.CharField()
key = serializers.CharField()
class Meta: class Meta:
field_aliases = COMMON_SEARCH_FIELD_ALIASES fields = ['key', 'title']
fields = COURSE_RUN_SEARCH_FIELDS + PROGRAM_SEARCH_FIELDS
ignore_fields = COMMON_IGNORED_FIELDS
serializers = { class TypeaheadProgramSearchSerializer(serializers.Serializer):
ProgramIndex: TypeaheadProgramSearchSerializer, orgs = serializers.SerializerMethodField()
CourseRunIndex: TypeaheadCourseRunSearchSerializer, uuid = serializers.CharField()
} title = serializers.CharField()
type = serializers.CharField()
def get_orgs(self, result):
authoring_organizations = [json.loads(org) for org in result.authoring_organization_bodies]
return [org['key'] for org in authoring_organizations]
class Meta:
fields = ['uuid', 'title', 'type']
class TypeaheadSearchSerializer(serializers.Serializer):
    """Serialize grouped typeahead results: a list of course runs and a list of programs."""
    # Declaration order is kept as-is; it determines the order of keys in the serialized output.
    course_runs = TypeaheadCourseRunSearchSerializer(many=True)
    programs = TypeaheadProgramSearchSerializer(many=True)
class AggregateFacetSearchSerializer(BaseHaystackFacetSerializer): class AggregateFacetSearchSerializer(BaseHaystackFacetSerializer):
......
...@@ -1110,8 +1110,7 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase): ...@@ -1110,8 +1110,7 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase):
expected = { expected = {
'key': course_run.key, 'key': course_run.key,
'title': course_run.title, 'title': course_run.title,
'content_type': 'courserun', 'org': course_run_key.org
'additional_details': course_run_key.org
} }
self.assertDictEqual(serialized_course.data, expected) self.assertDictEqual(serialized_course.data, expected)
...@@ -1128,16 +1127,13 @@ class TypeaheadProgramSearchSerializerTests(TestCase): ...@@ -1128,16 +1127,13 @@ class TypeaheadProgramSearchSerializerTests(TestCase):
'uuid': str(program.uuid), 'uuid': str(program.uuid),
'title': program.title, 'title': program.title,
'type': program.type.name, 'type': program.type.name,
'content_type': 'program', 'orgs': list(program.authoring_organizations.all().values_list('key', flat=True))
'additional_details': program.authoring_organizations.first().key
} }
def test_data(self): def test_data(self):
authoring_organization = OrganizationFactory() authoring_organization = OrganizationFactory()
program = ProgramFactory(authoring_organizations=[authoring_organization]) program = ProgramFactory(authoring_organizations=[authoring_organization])
serialized_program = self.serialize_program(program) serialized_program = self.serialize_program(program)
expected = self._create_expected_data(program) expected = self._create_expected_data(program)
self.assertDictEqual(serialized_program.data, expected) self.assertDictEqual(serialized_program.data, expected)
...@@ -1145,8 +1141,8 @@ class TypeaheadProgramSearchSerializerTests(TestCase): ...@@ -1145,8 +1141,8 @@ class TypeaheadProgramSearchSerializerTests(TestCase):
authoring_organizations = OrganizationFactory.create_batch(3) authoring_organizations = OrganizationFactory.create_batch(3)
program = ProgramFactory(authoring_organizations=authoring_organizations) program = ProgramFactory(authoring_organizations=authoring_organizations)
serialized_program = self.serialize_program(program) serialized_program = self.serialize_program(program)
expected = ', '.join([org.key for org in authoring_organizations]) expected = [org.key for org in authoring_organizations]
self.assertEqual(serialized_program.data['additional_details'], expected) self.assertEqual(serialized_program.data['orgs'], expected)
def serialize_program(self, program): def serialize_program(self, program):
""" Serializes the given `Program` as a typeahead result. """ """ Serializes the given `Program` as a typeahead result. """
......
...@@ -10,8 +10,7 @@ from rest_framework.test import APITestCase ...@@ -10,8 +10,7 @@ from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer, from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer) TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.views import RESULT_COUNT from course_discovery.apps.api.v1.views import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD, PartnerFactory from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD, PartnerFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
...@@ -33,15 +32,11 @@ class TypeaheadSerializationMixin: ...@@ -33,15 +32,11 @@ class TypeaheadSerializationMixin:
def serialize_course_run(self, course_run): def serialize_course_run(self, course_run):
result = SearchQuerySet().models(CourseRun).filter(key=course_run.key)[0] result = SearchQuerySet().models(CourseRun).filter(key=course_run.key)[0]
data = TypeaheadCourseRunSearchSerializer(result).data data = TypeaheadCourseRunSearchSerializer(result).data
# Items are grouped by content type so we don't need it in the response
data.pop('content_type')
return data return data
def serialize_program(self, program): def serialize_program(self, program):
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0] result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
data = TypeaheadProgramSearchSerializer(result).data data = TypeaheadProgramSearchSerializer(result).data
# Items are grouped by content type so we don't need it in the response
data.pop('content_type')
return data return data
...@@ -300,27 +295,37 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -300,27 +295,37 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
[self.serialize_course_run(course_run), self.serialize_program(program)]) [self.serialize_course_run(course_run), self.serialize_program(program)])
class TypeaheadSearchViewSet(TypeaheadSerializationMixin, LoginMixin, APITestCase): class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, APITestCase):
path = reverse('api:v1:search-typeahead-list') path = reverse('api:v1:search-typeahead')
def get_typeahead_response(self, query=None):
qs = ''
def get_typeahead_response(self): if query:
return self.client.get(self.path) qs = urllib.parse.urlencode({'q': query})
url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url)
def test_typeahead(self): def test_typeahead(self):
""" Test typeahead response. """ """ Test typeahead response. """
course_run = CourseRunFactory() title = "Python"
program = ProgramFactory() course_run = CourseRunFactory(title=title)
response = self.get_typeahead_response() program = ProgramFactory(title=title)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)], self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]}) 'programs': [self.serialize_program(program)]})
def test_typeahead_multiple_results(self): def test_typeahead_multiple_results(self):
""" Test typeahead response with max number of course_runs and programs. """ """ Verify the typeahead responses always returns a limited number of results, even if there are more hits. """
CourseRunFactory.create_batch(RESULT_COUNT + 1) RESULT_COUNT = TypeaheadSearchView.RESULT_COUNT
ProgramFactory.create_batch(RESULT_COUNT + 1) title = "Test"
response = self.get_typeahead_response() for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i))
ProgramFactory(title="{}{}".format(title, i))
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertEqual(len(response_data['course_runs']), RESULT_COUNT) self.assertEqual(len(response_data['course_runs']), RESULT_COUNT)
...@@ -328,11 +333,31 @@ class TypeaheadSearchViewSet(TypeaheadSerializationMixin, LoginMixin, APITestCas ...@@ -328,11 +333,31 @@ class TypeaheadSearchViewSet(TypeaheadSerializationMixin, LoginMixin, APITestCas
def test_typeahead_multiple_authoring_organizations(self): def test_typeahead_multiple_authoring_organizations(self):
""" Test typeahead response with multiple authoring organizations. """ """ Test typeahead response with multiple authoring organizations. """
title = "Design"
authoring_organizations = OrganizationFactory.create_batch(3) authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(authoring_organizations=authoring_organizations) course_run = CourseRunFactory(title=title, authoring_organizations=authoring_organizations)
program = ProgramFactory(authoring_organizations=authoring_organizations) program = ProgramFactory(title=title, authoring_organizations=authoring_organizations)
response = self.get_typeahead_response() response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)], self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]}) 'programs': [self.serialize_program(program)]})
def test_partial_term_search(self):
    """ Verify a query covering only part of the title still returns the course run and program. """
    full_title = "Learn Data Science"
    course_run = CourseRunFactory(title=full_title)
    program = ProgramFactory(title=full_title)

    # The query stops mid-word ("Sci"), so it is not a whole-term match for the title.
    response = self.get_typeahead_response("Data Sci")

    self.assertEqual(response.status_code, 200)
    self.assertDictEqual(
        response.json(),
        {
            'course_runs': [self.serialize_course_run(course_run)],
            'programs': [self.serialize_program(program)],
        }
    )
def test_exception(self):
    """ Verify a request without the 'q' query string parameter gets a 400 with an explanatory detail. """
    expected_error = {'detail': 'The \'q\' querystring parameter is required for searching.'}

    # Calling the helper with no argument sends the request without a 'q' value.
    response = self.get_typeahead_response()

    self.assertEqual(response.status_code, 400)
    self.assertDictEqual(response.data, expected_error)
...@@ -9,6 +9,7 @@ partners_router.register(r'affiliate_window/catalogs', views.AffiliateWindowView ...@@ -9,6 +9,7 @@ partners_router.register(r'affiliate_window/catalogs', views.AffiliateWindowView
partners_urls = partners_router.urls partners_urls = partners_router.urls
urlpatterns = [ urlpatterns = [
url(r'^partners/', include(partners_urls, namespace='partners')), url(r'^partners/', include(partners_urls, namespace='partners')),
url(r'search/typeahead', views.TypeaheadSearchView.as_view(), name='search-typeahead')
] ]
router = routers.SimpleRouter() router = routers.SimpleRouter()
...@@ -18,7 +19,6 @@ router.register(r'course_runs', views.CourseRunViewSet, base_name='course_run') ...@@ -18,7 +19,6 @@ router.register(r'course_runs', views.CourseRunViewSet, base_name='course_run')
router.register(r'management', views.ManagementViewSet, base_name='management') router.register(r'management', views.ManagementViewSet, base_name='management')
router.register(r'programs', views.ProgramViewSet, base_name='program') router.register(r'programs', views.ProgramViewSet, base_name='program')
router.register(r'search/all', views.AggregateSearchViewSet, base_name='search-all') router.register(r'search/all', views.AggregateSearchViewSet, base_name='search-all')
router.register(r'search/typeahead', views.TypeaheadSearchViewSet, base_name='search-typeahead')
router.register(r'search/courses', views.CourseSearchViewSet, base_name='search-courses') router.register(r'search/courses', views.CourseSearchViewSet, base_name='search-courses')
router.register(r'search/course_runs', views.CourseRunSearchViewSet, base_name='search-course_runs') router.register(r'search/course_runs', views.CourseRunSearchViewSet, base_name='search-course_runs')
router.register(r'search/programs', views.ProgramSearchViewSet, base_name='search-programs') router.register(r'search/programs', views.ProgramSearchViewSet, base_name='search-programs')
......
...@@ -15,13 +15,14 @@ from drf_haystack.viewsets import HaystackViewSet ...@@ -15,13 +15,14 @@ from drf_haystack.viewsets import HaystackViewSet
from dry_rest_permissions.generics import DRYPermissions from dry_rest_permissions.generics import DRYPermissions
from edx_rest_framework_extensions.permissions import IsSuperuser from edx_rest_framework_extensions.permissions import IsSuperuser
from haystack.inputs import AutoQuery from haystack.inputs import AutoQuery
from haystack.query import SQ from haystack.query import SQ, SearchQuerySet
from rest_framework import status, viewsets from rest_framework import status, viewsets
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.exceptions import PermissionDenied, ParseError from rest_framework.exceptions import PermissionDenied, ParseError
from rest_framework.filters import DjangoFilterBackend, OrderingFilter from rest_framework.filters import DjangoFilterBackend, OrderingFilter
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView
from course_discovery.apps.api import filters from course_discovery.apps.api import filters
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
...@@ -37,8 +38,6 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part ...@@ -37,8 +38,6 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
RESULT_COUNT = 3
def get_query_param(request, name): def get_query_param(request, name):
""" """
...@@ -680,28 +679,24 @@ class AggregateSearchViewSet(BaseHaystackViewSet): ...@@ -680,28 +679,24 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchViewSet(BaseHaystackViewSet): class TypeaheadSearchView(APIView):
""" """
Typeahead for courses and programs. Typeahead for courses and programs.
""" """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
serializer_class = serializers.TypeaheadSearchSerializer def get_results(self, query):
index_models = (CourseRun, Program,) query = '*{}*'.format(query.lower())
course_runs = SearchQuerySet().models(CourseRun).raw_search(query)[:self.RESULT_COUNT]
programs = SearchQuerySet().models(Program).raw_search(query)[:self.RESULT_COUNT]
return course_runs, programs
def list(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
response = super(TypeaheadSearchViewSet, self).list(request, *args, **kwargs) query = request.query_params.get('q')
results = response.data['results'] if not query:
course_runs, programs = [], [] raise ParseError("The 'q' querystring parameter is required for searching.")
for item in results: course_runs, programs = self.get_results(query)
# Items are grouped by content type so we don't need it in the response data = {'course_runs': course_runs, 'programs': programs}
item_type = item.pop('content_type', None) serializer = serializers.TypeaheadSearchSerializer(data)
programs_length = len(programs) return Response(serializer.data, status=status.HTTP_200_OK)
course_run_length = len(course_runs)
if item_type == 'courserun' and course_run_length < RESULT_COUNT:
course_runs.append(item)
elif item_type == 'program' and programs_length < RESULT_COUNT:
programs.append(item)
elif programs_length == RESULT_COUNT and course_run_length == RESULT_COUNT:
break
response.data = {'course_runs': course_runs, 'programs': programs}
return response
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment