Commit 17f4da7e by Matthew Piatetsky

Enable configuration of synonyms for search and typeahead

ECOM-4959
parent 5d278776
...@@ -54,6 +54,46 @@ class LoginMixin: ...@@ -54,6 +54,46 @@ class LoginMixin:
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
class SynonymTestMixin:
    """Synonym tests shared by the search and typeahead view test cases.

    The host test case must supply ``self.partner`` and a
    ``process_response(query)`` method; both are used here but defined elsewhere.
    """

    def _assert_same_results(self, first_query, second_query):
        """Search for each term in turn and require identical processed responses."""
        first = self.process_response({'q': first_query})
        second = self.process_response({'q': second_query})
        self.assertDictEqual(first, second)

    def test_org_synonyms(self):
        """ Organization-name synonyms return the same results. """
        name = 'UniversityX'
        # The same organization list is attached to both the course run and the program.
        organizations = [OrganizationFactory(name='University')]
        CourseRunFactory(title=name, course__partner=self.partner, authoring_organizations=organizations)
        ProgramFactory(title=name, partner=self.partner, authoring_organizations=organizations)
        self._assert_same_results(name, 'University')

    def test_title_synonyms(self):
        """ Synonyms of a term in the title return the same results. """
        CourseRunFactory(title='HTML', course__partner=self.partner)
        ProgramFactory(title='HTML', partner=self.partner)
        self._assert_same_results('HTML5', 'HTML')

    def test_special_character_synonyms(self):
        """ Synonyms containing non-ASCII characters return the same results. """
        ProgramFactory(title='spanish', partner=self.partner)
        self._assert_same_results('spanish', 'español')

    def test_stemmed_synonyms(self):
        """ Synonyms still match after stemming by the snowball analyzer. """
        ProgramFactory(title='Running', partner=self.partner)
        self._assert_same_results('running', 'jogging')
class DefaultPartnerMixin: class DefaultPartnerMixin:
def setUp(self): def setUp(self):
super(DefaultPartnerMixin, self).setUp() super(DefaultPartnerMixin, self).setUp()
...@@ -67,7 +107,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -67,7 +107,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
faceted_path = reverse('api:v1:search-course_runs-facets') faceted_path = reverse('api:v1:search-course_runs-facets')
list_path = reverse('api:v1:search-course_runs-list') list_path = reverse('api:v1:search-course_runs-list')
def get_search_response(self, query=None, faceted=False): def get_response(self, query=None, faceted=False):
qs = '' qs = ''
if query: if query:
...@@ -77,11 +117,16 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -77,11 +117,16 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
url = '{path}?{qs}'.format(path=path, qs=qs) url = '{path}?{qs}'.format(path=path, qs=qs)
return self.client.get(url) return self.client.get(url)
def process_response(self, response):
    """Run a search and return the ``objects`` section of its JSON body.

    Despite its name, ``response`` is the query handed to ``get_response``.
    Asserts that the reported ``count`` is truthy before returning.
    """
    objects = self.get_response(response).json()['objects']
    self.assertTrue(objects['count'])
    return objects
@ddt.data(True, False) @ddt.data(True, False)
def test_authentication(self, faceted): def test_authentication(self, faceted):
""" Verify the endpoint requires authentication. """ """ Verify the endpoint requires authentication. """
self.client.logout() self.client.logout()
response = self.get_search_response(faceted=faceted) response = self.get_response(faceted=faceted)
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_search(self): def test_search(self):
...@@ -105,7 +150,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -105,7 +150,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
# Generate data that should be indexed and returned by the query # Generate data that should be indexed and returned by the query
course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing', course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing',
status=CourseRunStatus.Published) status=CourseRunStatus.Published)
response = self.get_search_response('software', faceted=faceted) response = self.get_response('software', faceted=faceted)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
...@@ -149,7 +194,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -149,7 +194,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=61), upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=61),
end=now + datetime.timedelta(days=90), status=CourseRunStatus.Published) end=now + datetime.timedelta(days=90), status=CourseRunStatus.Published)
response = self.get_search_response(faceted=True) response = self.get_response(faceted=True)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
...@@ -228,7 +273,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -228,7 +273,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=course_list, status=ProgramStatus.Active, excluded_course_runs=excluded_course_run_list) ProgramFactory(courses=course_list, status=ProgramStatus.Active, excluded_course_runs=excluded_course_run_list)
with self.assertNumQueries(6): with self.assertNumQueries(6):
response = self.get_search_response('software', faceted=False) response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
...@@ -251,7 +296,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -251,7 +296,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=[course_run.course], status=program_status) ProgramFactory(courses=[course_run.course], status=program_status)
with self.assertNumQueries(8): with self.assertNumQueries(8):
response = self.get_search_response('software', faceted=False) response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
...@@ -268,15 +313,24 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -268,15 +313,24 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt @ddt.ddt
class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase): class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets') path = reverse('api:v1:search-all-facets')
def get_search_response(self, querystring=None): def get_response(self, query=None):
querystring = querystring or {} qs = ''
qs = urllib.parse.urlencode(querystring)
if query:
qs = urllib.parse.urlencode(query)
url = '{path}?{qs}'.format(path=self.path, qs=qs) url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url) return self.client.get(url)
def process_response(self, response):
    """Query the endpoint and return the ``objects`` part of the JSON reply.

    ``response`` is actually the query forwarded to ``get_response``.
    The ``count`` inside ``objects`` is asserted to be truthy.
    """
    body = self.get_response(response).json()
    found = body['objects']
    self.assertTrue(found['count'])
    return found
def test_results_only_include_published_objects(self): def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """ """ Verify the search results only include items with status set to 'Published'. """
# These items should NOT be in the results # These items should NOT be in the results
...@@ -286,7 +340,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -286,7 +340,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published) course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published)
program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active) program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active)
response = self.get_search_response() response = self.get_response()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual( self.assertListEqual(
...@@ -301,7 +355,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -301,7 +355,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertEqual(CourseRun.objects.get(hidden=True), hidden_run) self.assertEqual(CourseRun.objects.get(hidden=True), hidden_run)
response = self.get_search_response() response = self.get_response()
data = json.loads(response.content.decode('utf-8')) data = json.loads(response.content.decode('utf-8'))
self.assertEqual( self.assertEqual(
data['objects']['results'], data['objects']['results'],
...@@ -321,7 +375,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -321,7 +375,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertNotEqual(other_program.partner.short_code, self.partner.short_code) self.assertNotEqual(other_program.partner.short_code, self.partner.short_code)
self.assertNotEqual(other_course_run.course.partner.short_code, self.partner.short_code) self.assertNotEqual(other_course_run.course.partner.short_code, self.partner.short_code)
response = self.get_search_response() response = self.get_response()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual( self.assertListEqual(
...@@ -330,7 +384,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -330,7 +384,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
) )
# Filter results by partner # Filter results by partner
response = self.get_search_response({'partner': other_partner.short_code}) response = self.get_response({'partner': other_partner.short_code})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'], self.assertListEqual(response_data['objects']['results'],
...@@ -342,7 +396,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -342,7 +396,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published) course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published)
program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active) program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active)
response = self.get_search_response({'q': '', 'content_type': ['courserun', 'program']}) response = self.get_response({'q': '', 'content_type': ['courserun', 'program']})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8')) response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'], self.assertListEqual(response_data['objects']['results'],
...@@ -358,7 +412,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -358,7 +412,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(weeks=4)) upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(weeks=4))
course_run_keys = [course_run.key for course_run in [archived, current, starting_soon, upcoming]] course_run_keys = [course_run.key for course_run in [archived, current, starting_soon, upcoming]]
response = self.get_search_response({"ordering": ordering}) response = self.get_response({"ordering": ordering})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data['objects']['count'], 4) self.assertEqual(response.data['objects']['count'], 4)
...@@ -368,24 +422,28 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin ...@@ -368,24 +422,28 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin, class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase): SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead') path = reverse('api:v1:search-typeahead')
def get_typeahead_response(self, query=None, partner=None): def get_response(self, query=None, partner=None):
qs = '' query_dict = query or {}
query_dict = {'q': query, 'partner': partner or self.partner.short_code} query_dict.update({'partner': partner or self.partner.short_code})
if query: qs = urllib.parse.urlencode(query_dict)
qs = urllib.parse.urlencode(query_dict)
url = '{path}?{qs}'.format(path=self.path, qs=qs) url = '{path}?{qs}'.format(path=self.path, qs=qs)
return self.client.get(url) return self.client.get(url)
def process_response(self, response):
    """Run a typeahead query and return the whole JSON body.

    ``response`` is the query forwarded to ``get_response``. The body must
    hold at least one course run or program, otherwise the assertion fails.
    """
    payload = self.get_response(response).json()
    matches = payload['course_runs'] or payload['programs']
    self.assertTrue(matches)
    return payload
def test_typeahead(self): def test_typeahead(self):
""" Test typeahead response. """ """ Test typeahead response. """
title = "Python" title = "Python"
course_run = CourseRunFactory(title=title, course__partner=self.partner) course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)], self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
...@@ -398,7 +456,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -398,7 +456,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
for i in range(RESULT_COUNT + 1): for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i), course__partner=self.partner) CourseRunFactory(title="{}{}".format(title, i), course__partner=self.partner)
ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active, partner=self.partner) ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title) response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertEqual(len(response_data['course_runs']), RESULT_COUNT) self.assertEqual(len(response_data['course_runs']), RESULT_COUNT)
...@@ -417,7 +475,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -417,7 +475,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title=title, authoring_organizations=authoring_organizations, title=title, authoring_organizations=authoring_organizations,
status=ProgramStatus.Active, partner=self.partner status=ProgramStatus.Active, partner=self.partner
) )
response = self.get_typeahead_response(title) response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)], self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
...@@ -429,7 +487,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -429,7 +487,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
course_run = CourseRunFactory(title=title, course__partner=self.partner) course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
query = "Data Sci" query = "Data Sci"
response = self.get_typeahead_response(query) response = self.get_response({'q': query})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
expected_response_data = { expected_response_data = {
...@@ -441,23 +499,25 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -441,23 +499,25 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
def test_unpublished_and_hidden_courses(self): def test_unpublished_and_hidden_courses(self):
""" Verify that typeahead does not return unpublished or hidden courses """ Verify that typeahead does not return unpublished or hidden courses
or programs that are not active. """ or programs that are not active. """
title = "Supply Chain" title = "supply"
course_run = CourseRunFactory(title=title, course__partner=self.partner) course_run = CourseRunFactory(title=title, course__partner=self.partner)
CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner) CourseRunFactory(title=title + "unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner)
CourseRunFactory(title=title + "_hidden", hidden=True, course__partner=self.partner) CourseRunFactory(title=title + "hidden", hidden=True, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner) program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished, partner=self.partner) ProgramFactory(title=title + "unpublished", status=ProgramStatus.Unpublished, partner=self.partner)
query = "Supply" query = "suppl"
response = self.get_typeahead_response(query) response = self.get_response({'q': query})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response_data = response.json() response_data = response.json()
expected_response_data = {'course_runs': [self.serialize_course_run(course_run)], expected_response_data = {
'programs': [self.serialize_program(program)]} 'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]
}
self.assertDictEqual(response_data, expected_response_data) self.assertDictEqual(response_data, expected_response_data)
def test_exception(self): def test_exception(self):
""" Verify the view raises an error if the 'q' query string parameter is not provided. """ """ Verify the view raises an error if the 'q' query string parameter is not provided. """
response = self.get_typeahead_response() response = self.get_response()
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.data, ["The 'q' querystring parameter is required for searching."]) self.assertEqual(response.data, ["The 'q' querystring parameter is required for searching."])
...@@ -468,7 +528,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -468,7 +528,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
program = ProgramFactory(authoring_organizations=authoring_organizations, partner=self.partner) program = ProgramFactory(authoring_organizations=authoring_organizations, partner=self.partner)
partial_key = authoring_organizations[0].key[0:5] partial_key = authoring_organizations[0].key[0:5]
response = self.get_typeahead_response(partial_key) response = self.get_response({'q': partial_key})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
expected = { expected = {
'course_runs': [self.serialize_course_run(course_run)], 'course_runs': [self.serialize_course_run(course_run)],
...@@ -500,7 +560,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -500,7 +560,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title='MIT Testing2', title='MIT Testing2',
partner=self.partner partner=self.partner
) )
response = self.get_typeahead_response('mit') response = self.get_response({'q': 'mit'})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
expected = { expected = {
'course_runs': [self.serialize_course_run(mit_run), 'course_runs': [self.serialize_course_run(mit_run),
...@@ -523,7 +583,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -523,7 +583,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title=title, partner=partner, title=title, partner=partner,
status=ProgramStatus.Active status=ProgramStatus.Active
)) ))
response = self.get_typeahead_response('partner', 'edx') response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0] edx_course_run = course_runs[0]
edx_program = programs[0] edx_program = programs[0]
......
...@@ -3,6 +3,8 @@ import logging ...@@ -3,6 +3,8 @@ import logging
from django.conf import settings from django.conf import settings
from course_discovery.settings.process_synonyms import get_synonyms
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,6 +21,7 @@ class ElasticsearchUtils(object): ...@@ -19,6 +21,7 @@ class ElasticsearchUtils(object):
timestamp = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S") timestamp = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
index = '{alias}_{timestamp}'.format(alias=alias, timestamp=timestamp) index = '{alias}_{timestamp}'.format(alias=alias, timestamp=timestamp)
index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS
index_settings['settings']['analysis']['filter']['synonym']['synonyms'] = get_synonyms(es)
es.indices.create(index=index, body=index_settings) es.indices.create(index=index, body=index_settings)
logger.info('...index [%s] created.', index) logger.info('...index [%s] created.', index)
......
...@@ -93,14 +93,21 @@ class ConfigurableElasticBackend(ElasticsearchSearchBackend): ...@@ -93,14 +93,21 @@ class ConfigurableElasticBackend(ElasticsearchSearchBackend):
def build_schema(self, fields): def build_schema(self, fields):
content_field_name, mapping = super().build_schema(fields) content_field_name, mapping = super().build_schema(fields)
# Fields default to snowball analyzer, this keeps snowball functionality, but adds synonym functionality
snowball_with_synonyms = 'snowball_with_synonyms'
for field, value in mapping.items():
if value.get('analyzer') == 'snowball':
self.specify_analyzers(mapping=mapping, field=field,
index_analyzer=snowball_with_synonyms,
search_analyzer=snowball_with_synonyms)
# Use the ngram analyzer as the index_analyzer and the lowercase analyzer as the search_analyzer # Use the ngram analyzer as the index_analyzer and the lowercase analyzer as the search_analyzer
# This is necessary to support partial searches/typeahead # This is necessary to support partial searches/typeahead
# If we used ngram analyzer for both, then 'running' would get split into ngrams like "ing" # If we used ngram analyzer for both, then 'running' would get split into ngrams like "ing"
# and all words containing ing would come back in typeahead. # and all words containing ing would come back in typeahead.
self.specify_analyzers(mapping=mapping, field='title_autocomplete', self.specify_analyzers(mapping=mapping, field='title_autocomplete',
index_analyzer='ngram_analyzer', search_analyzer='lowercase') index_analyzer='ngram_analyzer', search_analyzer=snowball_with_synonyms)
self.specify_analyzers(mapping=mapping, field='authoring_organizations_autocomplete', self.specify_analyzers(mapping=mapping, field='authoring_organizations_autocomplete',
index_analyzer='ngram_analyzer', search_analyzer='lowercase') index_analyzer='ngram_analyzer', search_analyzer=snowball_with_synonyms)
return (content_field_name, mapping) return (content_field_name, mapping)
......
...@@ -5,6 +5,9 @@ from django.conf import settings ...@@ -5,6 +5,9 @@ from django.conf import settings
from haystack import connections as haystack_connections from haystack import connections as haystack_connections
from haystack.management.commands.update_index import Command as HaystackCommand from haystack.management.commands.update_index import Command as HaystackCommand
from course_discovery.settings.process_synonyms import get_synonyms
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -86,5 +89,6 @@ class Command(HaystackCommand): ...@@ -86,5 +89,6 @@ class Command(HaystackCommand):
timestamp = datetime.datetime.utcnow().strftime('%Y%m%d_%H%M%S') timestamp = datetime.datetime.utcnow().strftime('%Y%m%d_%H%M%S')
index_name = '{alias}_{timestamp}'.format(alias=prefix, timestamp=timestamp) index_name = '{alias}_{timestamp}'.format(alias=prefix, timestamp=timestamp)
index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS
index_settings['settings']['analysis']['filter']['synonym']['synonyms'] = get_synonyms(backend.conn)
backend.conn.indices.create(index=index_name, body=index_settings) backend.conn.indices.create(index=index_name, body=index_settings)
return index_name return index_name
...@@ -377,42 +377,48 @@ ELASTICSEARCH_INDEX_SETTINGS = { ...@@ -377,42 +377,48 @@ ELASTICSEARCH_INDEX_SETTINGS = {
'type': 'custom', 'type': 'custom',
'tokenizer': 'keyword', 'tokenizer': 'keyword',
'filter': [ 'filter': [
'lowercase' 'lowercase',
'synonym',
] ]
}, },
'ngram_analyzer': { 'snowball_with_synonyms': {
'type':'custom', 'type': 'custom',
'filter': [ 'filter': [
'haystack_ngram', 'standard',
'lowercase' 'lowercase',
'snowball',
'synonym'
], ],
'tokenizer': 'standard' 'tokenizer': 'standard'
}, },
'edgengram_analyzer': { 'ngram_analyzer': {
'type': 'custom', 'type':'custom',
'filter': [ 'filter': [
'haystack_edgengram', 'lowercase',
'lowercase' 'haystack_ngram',
'synonym',
], ],
'tokenizer': 'standard' 'tokenizer': 'keyword'
} }
}, },
'filter': { 'filter': {
'haystack_edgengram': {
'type': 'edgeNGram',
'min_gram': 2,
'max_gram': 15
},
'haystack_ngram': { 'haystack_ngram': {
'type': 'nGram', 'type': 'nGram',
'min_gram': 2, 'min_gram': 2,
'max_gram': 15 'max_gram': 22
},
'synonym' : {
'type': 'synonym',
'ignore_case': 'true',
'synonyms': []
} }
} }
} }
} }
} }
SYNONYMS_MODULE = 'course_discovery.settings.synonyms'
# Haystack configuration (http://django-haystack.readthedocs.io/en/v2.5.0/settings.html) # Haystack configuration (http://django-haystack.readthedocs.io/en/v2.5.0/settings.html)
HAYSTACK_ITERATOR_LOAD_PER_QUERY = 200 HAYSTACK_ITERATOR_LOAD_PER_QUERY = 200
......
from functools import lru_cache
import importlib
from django.conf import settings
def process_synonyms(es, synonyms):
    """Convert synonym groups to their snowball-analyzed form.

    Each group is a sequence of terms, e.g. ``['running', 'jogging']``. Every term
    is sent to ``es.indices.analyze`` with the ``snowball`` analyzer, and the
    returned tokens are joined with single spaces. The analyzed terms of a group
    are then joined with commas (for the example above, something like ``'run,jog'``).

    A term for which the analyzer returns no tokens is dropped, and a group left
    without any terms is skipped, so the result never contains empty synonym entries.

    Args:
        es (client): client for making requests to es
        synonyms (list): list of synonym groups (each group is a sequence of term strings)

    Returns:
        list: one comma-separated string of analyzed terms per non-empty group
    """
    processed_synonyms = []
    for group in synonyms:
        analyzed_terms = []
        for term in group:
            response = es.indices.analyze(text=term, analyzer='snowball')
            analyzed = ' '.join(item['token'] for item in response['tokens'] if item['token'])
            # An empty result would otherwise yield entries such as ',foo' in the synonym rule.
            if analyzed:
                analyzed_terms.append(analyzed)
        if analyzed_terms:
            processed_synonyms.append(','.join(analyzed_terms))
    return processed_synonyms
def get_synonym_lines_from_file():
    """Return the ``SYNONYMS`` attribute of the module named by ``settings.SYNONYMS_MODULE``."""
    module_path = settings.SYNONYMS_MODULE
    return importlib.import_module(module_path).SYNONYMS
@lru_cache()
def get_synonyms(es):
    """Load the configured synonym groups and return them in analyzed form.

    Results are memoized per ``es`` argument by ``lru_cache``, so repeated calls
    with the same client reuse the first result instead of re-analyzing.
    """
    return process_synonyms(es, get_synonym_lines_from_file())
...@@ -9,6 +9,8 @@ HAYSTACK_CONNECTIONS = { ...@@ -9,6 +9,8 @@ HAYSTACK_CONNECTIONS = {
}, },
} }
SYNONYMS_MODULE = 'course_discovery.settings.test_synonyms'
EDX_DRF_EXTENSIONS = { EDX_DRF_EXTENSIONS = {
'OAUTH2_USER_INFO_URL': 'http://example.com/oauth2/user_info', 'OAUTH2_USER_INFO_URL': 'http://example.com/oauth2/user_info',
} }
......
# Synonym groups for search and typeahead. Each inner list is one group of terms.
# Note: Do not use synonyms with punctuation, search and typeahead do not yet fully support punctuation
SYNONYMS = [
# Organizations
# NOTE(review): 'ACCA' is listed twice in the group below.
['ACCA', 'ACCA', 'ACCAx'],
['ACLU', 'American Civil Liberties Union'],
['Berkeley', 'UC BerkeleyX', 'UCBerkeleyX'],
['Georgia Institute of Technology', 'Georgia Tech', 'GTx'],
['Instituto Tecnologico y De Estudios Superiores De Monterrey', 'Monterrey', 'TecdeMonterreyX'],
['Microsoft', 'MicrosoftX', 'msft'],
['MIT', 'MITx'],
['New York Institute of Finance', 'NYIF', 'NYIFx'],
['The University of Michigan', 'MichiganX', 'UMichiganX', 'U Michigan'],
['The University of Texas System', 'UTx'],
['The University of California San Diego', 'UC San DiegoX', 'UCSanDiegoX'],
['The Disque Foundation', 'Save A LifeX', 'SaveALifeX'],
['University of Pennsylvania', 'PennX', 'UPennX', 'UPenn'],
['Universitat Politècnica de València', 'València', 'Valencia'],
['Wharton', 'WhartonX'],
# Common Misspellings
['cs50x', 'cs50'],
['ilets', 'ielts'],
['phyton', 'python'],
['toefl', 'tofel', 'toelf'],
# Subjects
['a11y', 'accessibility'],
['bi', 'business intelligence'],
['bme', 'biomedical engineering'],
['computer science', 'cs'],
['econ', 'economics'],
['ee', 'electrical engineering'],
['español', 'espanol', 'spanish'],
['français', 'francais', 'french'],
['it', 'information technology'],
['mis', 'management information systems'],
['psych', 'psychology'],
['seo', 'search engine optimization'],
['ux', 'user experience'],
# Other Terms
['autocad', 'auto cad', 'cad'],
['aws', 'amazon web services'],
['css', 'cascading style sheets'],
['excel', 'microsoft excel', 'msft excel'],
['hr', 'human resources'],
['HTML5', 'HTML'],
['iot', 'internet of things'],
['javascript', 'js', 'java script', 'react', 'typescript', 'jquery'],
['management', 'mgmt'],
['os', 'operating system'],
['photo', 'photography'],
['vb', 'visual basic'],
['vba', 'excel'],
['usa', 'united states of america', 'murika'],
]
# Test Synonyms
# Small fixed set of synonym groups; each pair matches the terms queried by the synonym tests.
SYNONYMS = [
["University", "UniversityX"],
["HTML5", "HTML"],
["running", "jogging"],
["spanish", "español"]
]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment