Commit 06eb2bc4 by Matthew Piatetsky

Add partial term matching to typeahead

ECOM-4738
parent 9e2663c7
......@@ -913,20 +913,6 @@ class CourseRunSearchSerializer(HaystackSerializer):
index_classes = [CourseRunIndex]
class TypeaheadCourseRunSearchSerializer(HaystackSerializer):
additional_details = serializers.SerializerMethodField()
def get_additional_details(self, result):
""" Value of the grey text next to the typeahead result title. """
return result.org
class Meta:
field_aliases = COMMON_SEARCH_FIELD_ALIASES
fields = ['key', 'title', 'content_type']
ignore_fields = COMMON_IGNORED_FIELDS
index_classes = [CourseRunIndex]
class CourseRunFacetSerializer(BaseHaystackFacetSerializer):
serialize_objects = True
......@@ -952,21 +938,6 @@ class ProgramSearchSerializer(HaystackSerializer):
index_classes = [ProgramIndex]
class TypeaheadProgramSearchSerializer(HaystackSerializer):
additional_details = serializers.SerializerMethodField()
def get_additional_details(self, result):
""" Value of the grey text next to the typeahead result title. """
authoring_organizations = [json.loads(org) for org in result.authoring_organization_bodies]
return ', '.join([org['key'] for org in authoring_organizations])
class Meta:
field_aliases = COMMON_SEARCH_FIELD_ALIASES
fields = ['uuid', 'title', 'content_type', 'type']
ignore_fields = COMMON_IGNORED_FIELDS
index_classes = [ProgramIndex]
class ProgramFacetSerializer(BaseHaystackFacetSerializer):
serialize_objects = True
......@@ -990,15 +961,32 @@ class AggregateSearchSerializer(HaystackSerializer):
}
class TypeaheadSearchSerializer(HaystackSerializer):
class TypeaheadCourseRunSearchSerializer(serializers.Serializer):
org = serializers.CharField()
title = serializers.CharField()
key = serializers.CharField()
class Meta:
field_aliases = COMMON_SEARCH_FIELD_ALIASES
fields = COURSE_RUN_SEARCH_FIELDS + PROGRAM_SEARCH_FIELDS
ignore_fields = COMMON_IGNORED_FIELDS
serializers = {
ProgramIndex: TypeaheadProgramSearchSerializer,
CourseRunIndex: TypeaheadCourseRunSearchSerializer,
}
fields = ['key', 'title']
class TypeaheadProgramSearchSerializer(serializers.Serializer):
orgs = serializers.SerializerMethodField()
uuid = serializers.CharField()
title = serializers.CharField()
type = serializers.CharField()
def get_orgs(self, result):
authoring_organizations = [json.loads(org) for org in result.authoring_organization_bodies]
return [org['key'] for org in authoring_organizations]
class Meta:
fields = ['uuid', 'title', 'type']
class TypeaheadSearchSerializer(serializers.Serializer):
course_runs = TypeaheadCourseRunSearchSerializer(many=True)
programs = TypeaheadProgramSearchSerializer(many=True)
class AggregateFacetSearchSerializer(BaseHaystackFacetSerializer):
......
......@@ -1110,8 +1110,7 @@ class TypeaheadCourseRunSearchSerializerTests(TestCase):
expected = {
'key': course_run.key,
'title': course_run.title,
'content_type': 'courserun',
'additional_details': course_run_key.org
'org': course_run_key.org
}
self.assertDictEqual(serialized_course.data, expected)
......@@ -1128,16 +1127,13 @@ class TypeaheadProgramSearchSerializerTests(TestCase):
'uuid': str(program.uuid),
'title': program.title,
'type': program.type.name,
'content_type': 'program',
'additional_details': program.authoring_organizations.first().key
'orgs': list(program.authoring_organizations.all().values_list('key', flat=True))
}
def test_data(self):
authoring_organization = OrganizationFactory()
program = ProgramFactory(authoring_organizations=[authoring_organization])
serialized_program = self.serialize_program(program)
expected = self._create_expected_data(program)
self.assertDictEqual(serialized_program.data, expected)
......@@ -1145,8 +1141,8 @@ class TypeaheadProgramSearchSerializerTests(TestCase):
authoring_organizations = OrganizationFactory.create_batch(3)
program = ProgramFactory(authoring_organizations=authoring_organizations)
serialized_program = self.serialize_program(program)
expected = ', '.join([org.key for org in authoring_organizations])
self.assertEqual(serialized_program.data['additional_details'], expected)
expected = [org.key for org in authoring_organizations]
self.assertEqual(serialized_program.data['orgs'], expected)
def serialize_program(self, program):
""" Serializes the given `Program` as a typeahead result. """
......
......@@ -10,8 +10,7 @@ from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.views import RESULT_COUNT
from course_discovery.apps.api.v1.views import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD, PartnerFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
......@@ -33,15 +32,11 @@ class TypeaheadSerializationMixin:
def serialize_course_run(self, course_run):
result = SearchQuerySet().models(CourseRun).filter(key=course_run.key)[0]
data = TypeaheadCourseRunSearchSerializer(result).data
# Items are grouped by content type so we don't need it in the response
data.pop('content_type')
return data
def serialize_program(self, program):
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
data = TypeaheadProgramSearchSerializer(result).data
# Items are grouped by content type so we don't need it in the response
data.pop('content_type')
return data
......@@ -300,27 +295,37 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
[self.serialize_course_run(course_run), self.serialize_program(program)])
class TypeaheadSearchViewSet(TypeaheadSerializationMixin, LoginMixin, APITestCase):
path = reverse('api:v1:search-typeahead-list')
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, APITestCase):
path = reverse('api:v1:search-typeahead')
def get_typeahead_response(self, query=None):
qs = ''
def get_typeahead_response(self):
return self.client.get(self.path)
if query:
qs = urllib.parse.urlencode({'q': query})
url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url)
def test_typeahead(self):
""" Test typeahead response. """
course_run = CourseRunFactory()
program = ProgramFactory()
response = self.get_typeahead_response()
title = "Python"
course_run = CourseRunFactory(title=title)
program = ProgramFactory(title=title)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]})
def test_typeahead_multiple_results(self):
""" Test typeahead response with max number of course_runs and programs. """
CourseRunFactory.create_batch(RESULT_COUNT + 1)
ProgramFactory.create_batch(RESULT_COUNT + 1)
response = self.get_typeahead_response()
""" Verify the typeahead responses always returns a limited number of results, even if there are more hits. """
RESULT_COUNT = TypeaheadSearchView.RESULT_COUNT
title = "Test"
for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i))
ProgramFactory(title="{}{}".format(title, i))
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertEqual(len(response_data['course_runs']), RESULT_COUNT)
......@@ -328,11 +333,31 @@ class TypeaheadSearchViewSet(TypeaheadSerializationMixin, LoginMixin, APITestCas
def test_typeahead_multiple_authoring_organizations(self):
""" Test typeahead response with multiple authoring organizations. """
title = "Design"
authoring_organizations = OrganizationFactory.create_batch(3)
course_run = CourseRunFactory(authoring_organizations=authoring_organizations)
program = ProgramFactory(authoring_organizations=authoring_organizations)
response = self.get_typeahead_response()
course_run = CourseRunFactory(title=title, authoring_organizations=authoring_organizations)
program = ProgramFactory(title=title, authoring_organizations=authoring_organizations)
response = self.get_typeahead_response(title)
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]})
def test_partial_term_search(self):
""" Test typeahead response with partial term search. """
title = "Learn Data Science"
course_run = CourseRunFactory(title=title)
program = ProgramFactory(title=title)
query = "Data Sci"
response = self.get_typeahead_response(query)
self.assertEqual(response.status_code, 200)
response_data = response.json()
expected_response_data = {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]}
self.assertDictEqual(response_data, expected_response_data)
def test_exception(self):
""" Verify the view raises an error if the 'q' query string parameter is not provided. """
response = self.get_typeahead_response()
self.assertEqual(response.status_code, 400)
self.assertDictEqual(response.data, {'detail': 'The \'q\' querystring parameter is required for searching.'})
......@@ -9,6 +9,7 @@ partners_router.register(r'affiliate_window/catalogs', views.AffiliateWindowView
partners_urls = partners_router.urls
urlpatterns = [
url(r'^partners/', include(partners_urls, namespace='partners')),
url(r'search/typeahead', views.TypeaheadSearchView.as_view(), name='search-typeahead')
]
router = routers.SimpleRouter()
......@@ -18,7 +19,6 @@ router.register(r'course_runs', views.CourseRunViewSet, base_name='course_run')
router.register(r'management', views.ManagementViewSet, base_name='management')
router.register(r'programs', views.ProgramViewSet, base_name='program')
router.register(r'search/all', views.AggregateSearchViewSet, base_name='search-all')
router.register(r'search/typeahead', views.TypeaheadSearchViewSet, base_name='search-typeahead')
router.register(r'search/courses', views.CourseSearchViewSet, base_name='search-courses')
router.register(r'search/course_runs', views.CourseRunSearchViewSet, base_name='search-course_runs')
router.register(r'search/programs', views.ProgramSearchViewSet, base_name='search-programs')
......
......@@ -15,13 +15,14 @@ from drf_haystack.viewsets import HaystackViewSet
from dry_rest_permissions.generics import DRYPermissions
from edx_rest_framework_extensions.permissions import IsSuperuser
from haystack.inputs import AutoQuery
from haystack.query import SQ
from haystack.query import SQ, SearchQuerySet
from rest_framework import status, viewsets
from rest_framework.decorators import detail_route, list_route
from rest_framework.exceptions import PermissionDenied, ParseError
from rest_framework.filters import DjangoFilterBackend, OrderingFilter
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from course_discovery.apps.api import filters
from course_discovery.apps.api import serializers
......@@ -37,8 +38,6 @@ from course_discovery.apps.course_metadata.models import Course, CourseRun, Part
logger = logging.getLogger(__name__)
User = get_user_model()
RESULT_COUNT = 3
def get_query_param(request, name):
"""
......@@ -680,28 +679,24 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchViewSet(BaseHaystackViewSet):
class TypeaheadSearchView(APIView):
"""
Typeahead for courses and programs.
"""
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
serializer_class = serializers.TypeaheadSearchSerializer
index_models = (CourseRun, Program,)
def get_results(self, query):
query = '*{}*'.format(query.lower())
course_runs = SearchQuerySet().models(CourseRun).raw_search(query)[:self.RESULT_COUNT]
programs = SearchQuerySet().models(Program).raw_search(query)[:self.RESULT_COUNT]
return course_runs, programs
def list(self, request, *args, **kwargs):
response = super(TypeaheadSearchViewSet, self).list(request, *args, **kwargs)
results = response.data['results']
course_runs, programs = [], []
for item in results:
# Items are grouped by content type so we don't need it in the response
item_type = item.pop('content_type', None)
programs_length = len(programs)
course_run_length = len(course_runs)
if item_type == 'courserun' and course_run_length < RESULT_COUNT:
course_runs.append(item)
elif item_type == 'program' and programs_length < RESULT_COUNT:
programs.append(item)
elif programs_length == RESULT_COUNT and course_run_length == RESULT_COUNT:
break
response.data = {'course_runs': course_runs, 'programs': programs}
return response
def get(self, request, *args, **kwargs):
query = request.query_params.get('q')
if not query:
raise ParseError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query)
data = {'course_runs': course_runs, 'programs': programs}
serializer = serializers.TypeaheadSearchSerializer(data)
return Response(serializer.data, status=status.HTTP_200_OK)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment