Commit 17f4da7e by Matthew Piatetsky

Enable configuration of synonyms for search and typeahead

ECOM-4959
parent 5d278776
......@@ -54,6 +54,46 @@ class LoginMixin:
self.client.login(username=self.user.username, password=USER_PASSWORD)
class SynonymTestMixin:
    """Synonym tests shared by the search and typeahead view test cases.

    The host test case is expected to provide ``self.partner`` and a
    ``process_response(query)`` method that runs the query and returns the
    part of the response that should be compared.
    """

    def _assert_same_results(self, first_term, second_term):
        """ Run both queries, in order, and assert that their results are equal. """
        first = self.process_response({'q': first_term})
        second = self.process_response({'q': second_term})
        self.assertDictEqual(first, second)

    def test_org_synonyms(self):
        """ Test that synonyms work for organization names """
        name = 'UniversityX'
        orgs = [OrganizationFactory(name='University')]
        CourseRunFactory(title=name, course__partner=self.partner, authoring_organizations=orgs)
        ProgramFactory(title=name, partner=self.partner, authoring_organizations=orgs)
        self._assert_same_results(name, 'University')

    def test_title_synonyms(self):
        """ Test that synonyms work for terms in the title """
        CourseRunFactory(title='HTML', course__partner=self.partner)
        ProgramFactory(title='HTML', partner=self.partner)
        self._assert_same_results('HTML5', 'HTML')

    def test_special_character_synonyms(self):
        """ Test that synonyms work with special characters (non ascii) """
        ProgramFactory(title='spanish', partner=self.partner)
        self._assert_same_results('spanish', 'español')

    def test_stemmed_synonyms(self):
        """ Test that synonyms work with stemming from the snowball analyzer """
        ProgramFactory(title='Running', partner=self.partner)
        self._assert_same_results('running', 'jogging')
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
......@@ -67,7 +107,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
faceted_path = reverse('api:v1:search-course_runs-facets')
list_path = reverse('api:v1:search-course_runs-list')
def get_search_response(self, query=None, faceted=False):
def get_response(self, query=None, faceted=False):
qs = ''
if query:
......@@ -77,11 +117,16 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
url = '{path}?{qs}'.format(path=path, qs=qs)
return self.client.get(url)
def process_response(self, response):
    """Run a search and return the ``objects`` section of the JSON body.

    Despite its name, ``response`` is the query that is handed straight to
    ``self.get_response``. The test fails if the result count is falsy.
    """
    objects = self.get_response(response).json()['objects']
    self.assertTrue(objects['count'])
    return objects
@ddt.data(True, False)
def test_authentication(self, faceted):
""" Verify the endpoint requires authentication. """
self.client.logout()
response = self.get_search_response(faceted=faceted)
response = self.get_response(faceted=faceted)
self.assertEqual(response.status_code, 403)
def test_search(self):
......@@ -105,7 +150,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
# Generate data that should be indexed and returned by the query
course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing',
status=CourseRunStatus.Published)
response = self.get_search_response('software', faceted=faceted)
response = self.get_response('software', faceted=faceted)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
......@@ -149,7 +194,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(days=61),
end=now + datetime.timedelta(days=90), status=CourseRunStatus.Published)
response = self.get_search_response(faceted=True)
response = self.get_response(faceted=True)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
......@@ -228,7 +273,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=course_list, status=ProgramStatus.Active, excluded_course_runs=excluded_course_run_list)
with self.assertNumQueries(6):
response = self.get_search_response('software', faceted=False)
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
response_data = response.json()
......@@ -251,7 +296,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=[course_run.course], status=program_status)
with self.assertNumQueries(8):
response = self.get_search_response('software', faceted=False)
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
......@@ -268,15 +313,24 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt
class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, APITestCase):
class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets')
def get_search_response(self, querystring=None):
querystring = querystring or {}
qs = urllib.parse.urlencode(querystring)
def get_response(self, query=None):
    """GET ``self.path`` with ``query`` urlencoded as the querystring.

    A falsy ``query`` produces an empty querystring (the URL still ends
    with ``?``). Returns whatever ``self.client.get`` returns.
    """
    encoded = urllib.parse.urlencode(query) if query else ''
    return self.client.get('{path}?{qs}'.format(path=self.path, qs=encoded))
def process_response(self, response):
    """Query the endpoint and return the ``objects`` part of the JSON body.

    ``response`` is really the query dict forwarded to ``self.get_response``.
    Asserts that the result count is truthy before returning.
    """
    body = self.get_response(response).json()
    found = body['objects']
    self.assertTrue(found['count'])
    return found
def test_results_only_include_published_objects(self):
""" Verify the search results only include items with status set to 'Published'. """
# These items should NOT be in the results
......@@ -286,7 +340,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published)
program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active)
response = self.get_search_response()
response = self.get_response()
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(
......@@ -301,7 +355,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertEqual(CourseRun.objects.get(hidden=True), hidden_run)
response = self.get_search_response()
response = self.get_response()
data = json.loads(response.content.decode('utf-8'))
self.assertEqual(
data['objects']['results'],
......@@ -321,7 +375,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
self.assertNotEqual(other_program.partner.short_code, self.partner.short_code)
self.assertNotEqual(other_course_run.course.partner.short_code, self.partner.short_code)
response = self.get_search_response()
response = self.get_response()
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(
......@@ -330,7 +384,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
)
# Filter results by partner
response = self.get_search_response({'partner': other_partner.short_code})
response = self.get_response({'partner': other_partner.short_code})
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
......@@ -342,7 +396,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
course_run = CourseRunFactory(course__partner=self.partner, status=CourseRunStatus.Published)
program = ProgramFactory(partner=self.partner, status=ProgramStatus.Active)
response = self.get_search_response({'q': '', 'content_type': ['courserun', 'program']})
response = self.get_response({'q': '', 'content_type': ['courserun', 'program']})
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content.decode('utf-8'))
self.assertListEqual(response_data['objects']['results'],
......@@ -358,7 +412,7 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
upcoming = CourseRunFactory(course__partner=self.partner, start=now + datetime.timedelta(weeks=4))
course_run_keys = [course_run.key for course_run in [archived, current, starting_soon, upcoming]]
response = self.get_search_response({"ordering": ordering})
response = self.get_response({"ordering": ordering})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data['objects']['count'], 4)
......@@ -368,24 +422,28 @@ class AggregateSearchViewSet(DefaultPartnerMixin, SerializationMixin, LoginMixin
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead')
def get_typeahead_response(self, query=None, partner=None):
qs = ''
query_dict = {'q': query, 'partner': partner or self.partner.short_code}
if query:
qs = urllib.parse.urlencode(query_dict)
def get_response(self, query=None, partner=None):
    """GET ``self.path`` with ``query`` plus a ``partner`` parameter.

    Args:
        query (dict): Optional querystring parameters. The dict is copied,
            so the caller's object is never modified.
        partner (str): Partner short code to send. Defaults to
            ``self.partner.short_code`` when falsy.

    Returns:
        Whatever ``self.client.get`` returns for the built URL.
    """
    # Copy first: updating ``query`` in place would leak the ``partner`` key
    # back into the dict owned by the caller.
    query_dict = dict(query or {})
    query_dict['partner'] = partner or self.partner.short_code
    qs = urllib.parse.urlencode(query_dict)
    url = '{path}?{qs}'.format(path=self.path, qs=qs)
    return self.client.get(url)
def process_response(self, response):
    """Run a typeahead query and return the whole JSON body.

    ``response`` is the query dict forwarded to ``self.get_response``. The
    test fails unless at least one course run or program came back.
    """
    payload = self.get_response(response).json()
    has_hits = payload['course_runs'] or payload['programs']
    self.assertTrue(has_hits)
    return payload
def test_typeahead(self):
""" Test typeahead response. """
title = "Python"
course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title)
response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
......@@ -398,7 +456,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
for i in range(RESULT_COUNT + 1):
CourseRunFactory(title="{}{}".format(title, i), course__partner=self.partner)
ProgramFactory(title="{}{}".format(title, i), status=ProgramStatus.Active, partner=self.partner)
response = self.get_typeahead_response(title)
response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertEqual(len(response_data['course_runs']), RESULT_COUNT)
......@@ -417,7 +475,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title=title, authoring_organizations=authoring_organizations,
status=ProgramStatus.Active, partner=self.partner
)
response = self.get_typeahead_response(title)
response = self.get_response({'q': title})
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertDictEqual(response_data, {'course_runs': [self.serialize_course_run(course_run)],
......@@ -429,7 +487,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
course_run = CourseRunFactory(title=title, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
query = "Data Sci"
response = self.get_typeahead_response(query)
response = self.get_response({'q': query})
self.assertEqual(response.status_code, 200)
response_data = response.json()
expected_response_data = {
......@@ -441,23 +499,25 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
def test_unpublished_and_hidden_courses(self):
""" Verify that typeahead does not return unpublished or hidden courses
or programs that are not active. """
title = "Supply Chain"
title = "supply"
course_run = CourseRunFactory(title=title, course__partner=self.partner)
CourseRunFactory(title=title + "_unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner)
CourseRunFactory(title=title + "_hidden", hidden=True, course__partner=self.partner)
CourseRunFactory(title=title + "unpublished", status=CourseRunStatus.Unpublished, course__partner=self.partner)
CourseRunFactory(title=title + "hidden", hidden=True, course__partner=self.partner)
program = ProgramFactory(title=title, status=ProgramStatus.Active, partner=self.partner)
ProgramFactory(title=title + "_unpublished", status=ProgramStatus.Unpublished, partner=self.partner)
query = "Supply"
response = self.get_typeahead_response(query)
ProgramFactory(title=title + "unpublished", status=ProgramStatus.Unpublished, partner=self.partner)
query = "suppl"
response = self.get_response({'q': query})
self.assertEqual(response.status_code, 200)
response_data = response.json()
expected_response_data = {'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]}
expected_response_data = {
'course_runs': [self.serialize_course_run(course_run)],
'programs': [self.serialize_program(program)]
}
self.assertDictEqual(response_data, expected_response_data)
def test_exception(self):
""" Verify the view raises an error if the 'q' query string parameter is not provided. """
response = self.get_typeahead_response()
response = self.get_response()
self.assertEqual(response.status_code, 400)
self.assertEqual(response.data, ["The 'q' querystring parameter is required for searching."])
......@@ -468,7 +528,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
program = ProgramFactory(authoring_organizations=authoring_organizations, partner=self.partner)
partial_key = authoring_organizations[0].key[0:5]
response = self.get_typeahead_response(partial_key)
response = self.get_response({'q': partial_key})
self.assertEqual(response.status_code, 200)
expected = {
'course_runs': [self.serialize_course_run(course_run)],
......@@ -500,7 +560,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title='MIT Testing2',
partner=self.partner
)
response = self.get_typeahead_response('mit')
response = self.get_response({'q': 'mit'})
self.assertEqual(response.status_code, 200)
expected = {
'course_runs': [self.serialize_course_run(mit_run),
......@@ -523,7 +583,7 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_typeahead_response('partner', 'edx')
response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
......
......@@ -3,6 +3,8 @@ import logging
from django.conf import settings
from course_discovery.settings.process_synonyms import get_synonyms
logger = logging.getLogger(__name__)
......@@ -19,6 +21,7 @@ class ElasticsearchUtils(object):
timestamp = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
index = '{alias}_{timestamp}'.format(alias=alias, timestamp=timestamp)
index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS
index_settings['settings']['analysis']['filter']['synonym']['synonyms'] = get_synonyms(es)
es.indices.create(index=index, body=index_settings)
logger.info('...index [%s] created.', index)
......
......@@ -93,14 +93,21 @@ class ConfigurableElasticBackend(ElasticsearchSearchBackend):
def build_schema(self, fields):
    """Build the index schema and point analyzed fields at the synonym-aware analyzers.

    Args:
        fields: Passed through unchanged to the parent ``build_schema``.

    Returns:
        tuple: ``(content_field_name, mapping)`` as produced by the parent
        class, with the analyzers adjusted through ``specify_analyzers``.
    """
    content_field_name, mapping = super().build_schema(fields)

    # Fields default to snowball analyzer, this keeps snowball functionality, but adds synonym functionality
    snowball_with_synonyms = 'snowball_with_synonyms'
    # Iterate over a snapshot so the loop does not depend on how
    # specify_analyzers edits ``mapping``.
    for field, value in list(mapping.items()):
        if value.get('analyzer') == 'snowball':
            self.specify_analyzers(mapping=mapping, field=field,
                                   index_analyzer=snowball_with_synonyms,
                                   search_analyzer=snowball_with_synonyms)

    # Use the ngram analyzer as the index_analyzer and the snowball_with_synonyms
    # analyzer as the search_analyzer.
    # This is necessary to support partial searches/typeahead
    # If we used ngram analyzer for both, then 'running' would get split into ngrams like "ing"
    # and all words containing ing would come back in typeahead.
    for autocomplete_field in ('title_autocomplete', 'authoring_organizations_autocomplete'):
        self.specify_analyzers(mapping=mapping, field=autocomplete_field,
                               index_analyzer='ngram_analyzer', search_analyzer=snowball_with_synonyms)

    return (content_field_name, mapping)
......
......@@ -5,6 +5,9 @@ from django.conf import settings
from haystack import connections as haystack_connections
from haystack.management.commands.update_index import Command as HaystackCommand
from course_discovery.settings.process_synonyms import get_synonyms
logger = logging.getLogger(__name__)
......@@ -86,5 +89,6 @@ class Command(HaystackCommand):
timestamp = datetime.datetime.utcnow().strftime('%Y%m%d_%H%M%S')
index_name = '{alias}_{timestamp}'.format(alias=prefix, timestamp=timestamp)
index_settings = settings.ELASTICSEARCH_INDEX_SETTINGS
index_settings['settings']['analysis']['filter']['synonym']['synonyms'] = get_synonyms(backend.conn)
backend.conn.indices.create(index=index_name, body=index_settings)
return index_name
......@@ -377,42 +377,48 @@ ELASTICSEARCH_INDEX_SETTINGS = {
'type': 'custom',
'tokenizer': 'keyword',
'filter': [
'lowercase'
'lowercase',
'synonym',
]
},
'ngram_analyzer': {
'type':'custom',
'snowball_with_synonyms': {
'type': 'custom',
'filter': [
'haystack_ngram',
'lowercase'
'standard',
'lowercase',
'snowball',
'synonym'
],
'tokenizer': 'standard'
},
'edgengram_analyzer': {
'type': 'custom',
'ngram_analyzer': {
'type':'custom',
'filter': [
'haystack_edgengram',
'lowercase'
'lowercase',
'haystack_ngram',
'synonym',
],
'tokenizer': 'standard'
'tokenizer': 'keyword'
}
},
'filter': {
'haystack_edgengram': {
'type': 'edgeNGram',
'min_gram': 2,
'max_gram': 15
},
'haystack_ngram': {
'type': 'nGram',
'min_gram': 2,
'max_gram': 15
'max_gram': 22
},
'synonym' : {
'type': 'synonym',
'ignore_case': 'true',
'synonyms': []
}
}
}
}
}
SYNONYMS_MODULE = 'course_discovery.settings.synonyms'
# Haystack configuration (http://django-haystack.readthedocs.io/en/v2.5.0/settings.html)
HAYSTACK_ITERATOR_LOAD_PER_QUERY = 200
......
from functools import lru_cache
import importlib
from django.conf import settings
def process_synonyms(es, synonyms):
    """Convert synonyms to analyzed form with snowball analyzer.

    Each synonym group, e.g. ``['running', 'jogging']``, is analyzed term by
    term with the snowball analyzer and collapsed into a single comma
    separated line, e.g. ``'run,jog'``. Multi-word terms keep their tokens
    separated by a single space.

    Terms that the analyzer reduces to no tokens at all are dropped, and a
    group left without any term is skipped, so the result never contains
    empty entries such as ``',run'`` or ``''``.

    Args:
        es (client): client for making requests to es
        synonyms (iterable): synonym groups; each group is an iterable of
            term strings

    Returns:
        list: one comma separated string of analyzed terms per non-empty group
    """
    processed_synonyms = []
    for group in synonyms:
        analyzed_terms = []
        for synonym in group:
            response = es.indices.analyze(text=synonym, analyzer='snowball')
            synonym_tokens = ' '.join(item['token'] for item in response['tokens'])
            # An analyzer can eliminate a term completely (for example a term
            # made only of stop words); an empty entry is useless in a
            # synonym rule, so leave it out.
            if synonym_tokens:
                analyzed_terms.append(synonym_tokens)
        if analyzed_terms:
            processed_synonyms.append(','.join(analyzed_terms))
    return processed_synonyms
def get_synonym_lines_from_file():
    """Return ``SYNONYMS`` from the module named by ``settings.SYNONYMS_MODULE``."""
    return importlib.import_module(settings.SYNONYMS_MODULE).SYNONYMS
@lru_cache()
def get_synonyms(es):
    """Load the configured synonym groups and return them in analyzed form.

    The result is memoized per ``es`` argument by ``lru_cache``, so repeated
    calls with the same client reuse the first result (the same list object)
    instead of analyzing every term again.
    """
    return process_synonyms(es, get_synonym_lines_from_file())
......@@ -9,6 +9,8 @@ HAYSTACK_CONNECTIONS = {
},
}
SYNONYMS_MODULE = 'course_discovery.settings.test_synonyms'
EDX_DRF_EXTENSIONS = {
'OAUTH2_USER_INFO_URL': 'http://example.com/oauth2/user_info',
}
......
# Synonym groups for search and typeahead.
# Each inner list is one group of terms that should be treated as equivalent.
# Note: Do not use synonyms with punctuation, search and typeahead do not yet fully support punctuation
SYNONYMS = [
# Organizations
# NOTE(review): 'ACCA' is listed twice in the next group — probably unintentional; confirm.
['ACCA', 'ACCA', 'ACCAx'],
['ACLU', 'American Civil Liberties Union'],
['Berkeley', 'UC BerkeleyX', 'UCBerkeleyX'],
['Georgia Institute of Technology', 'Georgia Tech', 'GTx'],
['Instituto Tecnologico y De Estudios Superiores De Monterrey', 'Monterrey', 'TecdeMonterreyX'],
['Microsoft', 'MicrosoftX', 'msft'],
['MIT', 'MITx'],
['New York Institute of Finance', 'NYIF', 'NYIFx'],
['The University of Michigan', 'MichiganX', 'UMichiganX', 'U Michigan'],
['The University of Texas System', 'UTx'],
['The University of California San Diego', 'UC San DiegoX', 'UCSanDiegoX'],
['The Disque Foundation', 'Save A LifeX', 'SaveALifeX'],
['University of Pennsylvania', 'PennX', 'UPennX', 'UPenn'],
['Universitat Politècnica de València', 'València', 'Valencia'],
['Wharton', 'WhartonX'],
# Common Misspellings
['cs50x', 'cs50'],
['ilets', 'ielts'],
['phyton', 'python'],
['toefl', 'tofel', 'toelf'],
# Subjects
['a11y', 'accessibility'],
['bi', 'business intelligence'],
['bme', 'biomedical engineering'],
['computer science', 'cs'],
['econ', 'economics'],
['ee', 'electrical engineering'],
['español', 'espanol', 'spanish'],
['français', 'francais', 'french'],
['it', 'information technology'],
['mis', 'management information systems'],
['psych', 'psychology'],
['seo', 'search engine optimization'],
['ux', 'user experience'],
# Other Terms
['autocad', 'auto cad', 'cad'],
['aws', 'amazon web services'],
['css', 'cascading style sheets'],
# NOTE(review): 'excel' also appears in the 'vba' group below — confirm the overlap is intended.
['excel', 'microsoft excel', 'msft excel'],
['hr', 'human resources'],
['HTML5', 'HTML'],
['iot', 'internet of things'],
['javascript', 'js', 'java script', 'react', 'typescript', 'jquery'],
['management', 'mgmt'],
['os', 'operating system'],
['photo', 'photography'],
['vb', 'visual basic'],
['vba', 'excel'],
['usa', 'united states of america', 'murika'],
]
# Test Synonyms
# Small, fixed set of synonym groups used when running the test suite.
# Each pair matches one of the SynonymTestMixin tests: organization name,
# title term, stemmed words, and non-ascii characters.
SYNONYMS = [
["University", "UniversityX"],
["HTML5", "HTML"],
["running", "jogging"],
["spanish", "español"]
]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment