Commit 28bca8e6 by Bill DeRusha

Merge pull request #9365 from edx/bderusha/edx-search-for-teams-TNL-3013

Add basic elasticsearch search for teams
parents 5d1bd225 973314de
......@@ -233,3 +233,4 @@ Dongwook Yoon <dy252@cornell.edu>
Awais Qureshi <awais.qureshi@arbisoft.com>
Eric Fischer <efischer@edx.org>
Brian Beggs <macdiesel@gmail.com>
Bill DeRusha <bill@edx.org>
\ No newline at end of file
""" Tests for library reindex command """
import sys
import contextlib
import ddt
from django.core.management import call_command, CommandError
import mock
......@@ -8,6 +6,7 @@ import mock
from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from common.test.utils import nostderr
from xmodule.modulestore.tests.factories import CourseFactory, LibraryFactory
from opaque_keys import InvalidKeyError
......@@ -16,27 +15,6 @@ from contentstore.management.commands.reindex_library import Command as ReindexC
from contentstore.courseware_index import SearchIndexingError
@contextlib.contextmanager
def nostderr():
    """
    Silence everything written to stderr while the ``with`` body runs.

    Adapted from http://stackoverflow.com/a/1810086/882918
    """
    class Devnull(object):
        """ Write-only sink that throws away whatever it receives """
        def write(self, _):
            """ Accept the data and discard it """
            pass

    # Remember the real stream and swap in the sink in one step.
    original_stream, sys.stderr = sys.stderr, Devnull()
    try:
        yield
    finally:
        # Always restore the real stream, even when the body raised.
        sys.stderr = original_stream
@ddt.ddt
class TestReindexLibrary(ModuleStoreTestCase):
""" Tests for library reindex command """
......
"""
General testing utilities.
"""
import sys
from contextlib import contextmanager
@contextmanager
def nostderr():
    """
    ContextManager to suppress stderr messages
    http://stackoverflow.com/a/1810086/882918

    While the ``with`` body runs, ``sys.stderr`` is replaced by an object
    that discards everything written to it. The original stream is put back
    when the body finishes, whether it returned normally or raised.
    """
    savestderr = sys.stderr

    class Devnull(object):
        """ /dev/null incarnation as output-stream-like object """
        def write(self, _):
            """ Write method - just does nothing"""
            pass

        def flush(self):
            """ Flush method - nothing is buffered, so nothing to do """
            # Present so that code which flushes stderr after writing does
            # not fail with AttributeError while output is suppressed.
            pass

    sys.stderr = Devnull()
    try:
        yield
    finally:
        sys.stderr = savestderr
""" Management command to update course_teams' search index. """
from django.core.management import BaseCommand, CommandError
from django.core.exceptions import ObjectDoesNotExist
from django.conf import settings
from optparse import make_option
from textwrap import dedent
from teams.models import CourseTeam
class Command(BaseCommand):
    """
    Command to reindex course_teams (single, multiple or all available).

    Examples:

        ./manage.py reindex_course_team team1 team2 - reindexes course teams with team_ids team1 and team2
        ./manage.py reindex_course_team --all - reindexes all available course teams
    """
    help = dedent(__doc__)

    can_import_settings = True

    args = "<course_team_id course_team_id ...>"

    option_list = BaseCommand.option_list + (
        make_option(
            '--all',
            action='store_true',
            dest='all',
            default=False,
            help='Reindex all course teams'
        ),
    )

    def _get_course_team(self, team_id):
        """
        Returns the CourseTeam whose team_id matches the given value.

        Raises CommandError when no such team exists.
        """
        try:
            result = CourseTeam.objects.get(team_id=team_id)
        except ObjectDoesNotExist:
            raise CommandError(u"Argument {0} is not a course_team team_id".format(team_id))

        return result

    def handle(self, *args, **options):
        """
        Reindex the requested course teams.

        Positional args are team_ids to reindex; with --all every CourseTeam
        is reindexed instead. Raises CommandError when neither is given, when
        ENABLE_TEAMS_SEARCH is off, or when a team_id does not exist.
        """
        # Imported here rather than at module level because a top-level
        # import triggers a circular dependency.
        from teams.search_indexes import CourseTeamIndexer

        reindex_all = options.get('all', False)

        if not args and not reindex_all:
            raise CommandError(u"reindex_course_team requires one or more arguments: <course_team_id>")
        elif not settings.FEATURES.get('ENABLE_TEAMS_SEARCH', False):
            raise CommandError(u"ENABLE_TEAMS_SEARCH must be enabled")

        if reindex_all:
            course_teams = CourseTeam.objects.all()
        else:
            # Resolve every id up front so an unknown id aborts the command
            # before any team has been indexed.
            course_teams = [self._get_course_team(team_id) for team_id in args]

        for course_team in course_teams:
            # Parenthesised single-argument form behaves the same as the
            # print statement on Python 2 and is valid on Python 3.
            print("Indexing {id}".format(id=course_team.team_id))
            CourseTeamIndexer.index(course_team)
""" Tests for course_team reindex command """
import ddt
import mock
from mock import patch
from django.core.management import call_command, CommandError
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from common.test.utils import nostderr
from opaque_keys.edx.keys import CourseKey
from teams.tests.factories import CourseTeamFactory
from teams.search_indexes import CourseTeamIndexer
from search.search_engine_base import SearchEngine
COURSE_KEY1 = CourseKey.from_string('edx/history/1')
@ddt.ddt
class ReindexCourseTeamTest(SharedModuleStoreTestCase):
    """Tests for the ReindexCourseTeam command"""

    def setUp(self):
        """
        Create three teams in the same course and grab the teams search engine.
        """
        super(ReindexCourseTeamTest, self).setUp()

        self.team1 = CourseTeamFactory(course_id=COURSE_KEY1, team_id='team1')
        self.team2 = CourseTeamFactory(course_id=COURSE_KEY1, team_id='team2')
        self.team3 = CourseTeamFactory(course_id=COURSE_KEY1, team_id='team3')

        # Use the indexer's own index name so the test looks at the same
        # index the command writes to.
        self.search_engine = SearchEngine.get_search_engine(index=CourseTeamIndexer.INDEX_NAME)

    def test_given_no_arguments_raises_command_error(self):
        """ Test that raises CommandError for incorrect arguments. """
        with self.assertRaises(SystemExit), nostderr():
            with self.assertRaisesRegexp(CommandError, "requires one or more arguments"):
                call_command('reindex_course_team')

    def test_teams_search_flag_disabled_raises_command_error(self):
        """ Test that raises CommandError for disabled feature flag. """
        # patch.dict really switches the flag off; replacing FEATURES with a
        # MagicMock leaves FEATURES.get(...) truthy. A valid team id is passed
        # so the command gets past the "no arguments" check and reaches the
        # feature-flag check.
        with mock.patch.dict('django.conf.settings.FEATURES', {'ENABLE_TEAMS_SEARCH': False}):
            with self.assertRaises(SystemExit), nostderr():
                with self.assertRaisesRegexp(CommandError, "ENABLE_TEAMS_SEARCH must be enabled"):
                    call_command('reindex_course_team', self.team1.team_id)

    def test_given_invalid_team_id_raises_command_error(self):
        """ Test that raises CommandError for invalid team id. """
        with self.assertRaises(SystemExit), nostderr():
            with self.assertRaisesRegexp(CommandError, "Argument team4 is not a course_team team_id"):
                call_command('reindex_course_team', u'team4')

    @patch.object(CourseTeamIndexer, 'index')
    def test_single_team_id(self, mock_index):
        """ Test that command indexes a single passed team. """
        call_command('reindex_course_team', self.team1.team_id)
        mock_index.assert_called_once_with(self.team1)

    @patch.object(CourseTeamIndexer, 'index')
    def test_multiple_team_id(self, mock_index):
        """ Test that command indexes multiple passed teams. """
        call_command('reindex_course_team', self.team1.team_id, self.team2.team_id)
        mock_index.assert_any_call(self.team1)
        mock_index.assert_any_call(self.team2)

    @patch.object(CourseTeamIndexer, 'index')
    def test_all_teams(self, mock_index):
        """ Test that command indexes all teams. """
        call_command('reindex_course_team', all=True)
        mock_index.assert_any_call(self.team1)
        mock_index.assert_any_call(self.team2)
        mock_index.assert_any_call(self.team3)
""" Search index used to load data into elasticsearch"""
from django.conf import settings
from django.db.models.signals import post_save
from django.dispatch import receiver
from search.search_engine_base import SearchEngine
from .serializers import CourseTeamSerializer, CourseTeam
class CourseTeamIndexer(object):
    """
    This is the index object for searching and storing CourseTeam model instances.
    """
    INDEX_NAME = "course_team_index"
    DOCUMENT_TYPE_NAME = "course_team"
    ENABLE_SEARCH_KEY = "ENABLE_TEAMS_SEARCH"

    def __init__(self, course_team):
        self.course_team = course_team

    def data(self):
        """
        Uses the CourseTeamSerializer to create a serialized course_team object.
        Adds in additional text and pk fields.
        Removes membership relation.

        Returns serialized object with additional search fields.
        """
        serialized_course_team = CourseTeamSerializer(self.course_team).data
        # Save the primary key so we can load the full objects easily after we search
        serialized_course_team['pk'] = self.course_team.pk
        # Don't save the membership relations in elasticsearch
        serialized_course_team.pop('membership', None)

        # add generally searchable content
        serialized_course_team['content'] = {
            'text': self.content_text()
        }

        return serialized_course_team

    def content_text(self):
        """
        Generate the text field used for general search.

        Returns the team's name, description, country name and language name
        joined by newlines, as unicode text.
        """
        # Format with a unicode template and unencoded values: mixing UTF-8
        # encoded bytes with unicode country/language names in a byte-string
        # template fails on Python 2 as soon as any value is non-ASCII.
        return u"{name}\n{description}\n{country}\n{language}".format(
            name=self.course_team.name,
            description=self.course_team.description,
            country=self.course_team.country.name.format(),
            language=self._language_name()
        )

    def _language_name(self):
        """
        Convert the language from code to long name.

        Falls back to the raw code when it is not listed in settings.ALL_LANGUAGES.
        """
        language_code = self.course_team.language
        return dict(settings.ALL_LANGUAGES).get(language_code, language_code)

    @classmethod
    def index(cls, course_team):
        """
        Update index with course_team object (if feature is enabled).
        """
        if cls.search_is_enabled():
            search_engine = cls.engine()
            serialized_course_team = cls(course_team).data()
            search_engine.index(cls.DOCUMENT_TYPE_NAME, [serialized_course_team])

    @classmethod
    def engine(cls):
        """
        Return course team search engine (if feature is enabled), otherwise None.
        """
        if cls.search_is_enabled():
            return SearchEngine.get_search_engine(index=cls.INDEX_NAME)
        return None

    @classmethod
    def search_is_enabled(cls):
        """
        Return boolean of whether course team indexing is enabled.
        """
        return settings.FEATURES.get(cls.ENABLE_SEARCH_KEY, False)
@receiver(post_save, sender=CourseTeam)
def course_team_post_save_callback(**kwargs):
    """
    post_save handler for CourseTeam: hand the saved instance to the indexer.
    """
    saved_team = kwargs['instance']
    CourseTeamIndexer.index(saved_team)
......@@ -17,6 +17,7 @@ from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentF
from student.models import CourseEnrollment
from xmodule.modulestore.tests.factories import CourseFactory
from .factories import CourseTeamFactory, LAST_ACTIVITY_AT
from ..search_indexes import CourseTeamIndexer
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA
......@@ -193,6 +194,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
username='student_enrolled_other_course_not_on_team'
)
# clear the teams search index before rebuilding teams
CourseTeamIndexer.engine().destroy()
# 'solar team' is intentionally lower case to test case insensitivity in name ordering
self.test_team_1 = CourseTeamFactory.create(
name=u'sólar team',
......@@ -208,6 +212,14 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
course_id=self.test_course_2.id,
topic_id='topic_6'
)
self.test_team_7 = CourseTeamFactory.create(
name='Search',
description='queryable text',
country='GS',
language='to',
course_id=self.test_course_2.id,
topic_id='topic_7'
)
self.test_team_name_id_map = {team.name: team for team in (
self.test_team_1,
......@@ -418,7 +430,7 @@ class TestListTeamsAPI(TeamAPITestCase):
self.verify_names(
{'course_id': self.test_course_2.id},
200,
['Another Team', 'Public Profile Team'],
['Another Team', 'Public Profile Team', 'Search'],
user='staff'
)
......@@ -428,11 +440,6 @@ class TestListTeamsAPI(TeamAPITestCase):
def test_filter_include_inactive(self):
self.verify_names({'include_inactive': True}, 200, ['Coal Team', 'Nuclear Team', u'sólar team', 'Wind Team'])
# Text search is not yet implemented, so this should return HTTP
# 400 for now
def test_filter_text_search(self):
self.verify_names({'text_search': 'foobar'}, 400)
@ddt.data(
(None, 200, ['Nuclear Team', u'sólar team', 'Wind Team']),
('name', 200, ['Nuclear Team', u'sólar team', 'Wind Team']),
......@@ -455,6 +462,10 @@ class TestListTeamsAPI(TeamAPITestCase):
data = {'order_by': field} if field else {}
self.verify_names(data, status, names)
def test_order_by_with_text_search(self):
    """ Requesting order_by together with text_search is expected to give a 400. """
    self.verify_names({'order_by': 'name', 'text_search': 'search'}, 400, [])
@ddt.data((404, {'course_id': 'no/such/course'}), (400, {'topic_id': 'no_such_topic'}))
@ddt.unpack
def test_no_results(self, status, data):
......@@ -487,6 +498,23 @@ class TestListTeamsAPI(TeamAPITestCase):
)
self.verify_expanded_public_user(result['results'][0]['membership'][0]['user'])
@ddt.data(
    ('search', ['Search']),
    ('queryable', ['Search']),
    ('Tonga', ['Search']),
    ('Island', ['Search']),
    ('search queryable', []),
    ('team', ['Another Team', 'Public Profile Team']),
)
@ddt.unpack
def test_text_search(self, text_search, expected_team_names):
    """
    Run a text search within test_course_2 as 'student_enrolled_public_profile'
    and check it answers 200 with the expected team names.
    """
    query = {'course_id': self.test_course_2.id, 'text_search': text_search}
    self.verify_names(query, 200, expected_team_names, user='student_enrolled_public_profile')
@ddt.ddt
class TestCreateTeamAPI(TeamAPITestCase):
......
......@@ -29,6 +29,7 @@ from openedx.core.lib.api.view_utils import (
ExpandableFieldViewMixin
)
from openedx.core.lib.api.serializers import PaginationSerializer
from openedx.core.lib.api.paginators import paginate_search_results
from xmodule.modulestore.django import modulestore
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
......@@ -49,10 +50,12 @@ from .serializers import (
PaginatedMembershipSerializer,
add_team_count
)
from .search_indexes import CourseTeamIndexer
from .errors import AlreadyOnTeamInCourse, NotEnrolledInCourseForTeam
TEAM_MEMBERSHIPS_PER_PAGE = 2
TOPICS_PER_PAGE = 12
MAXIMUM_SEARCH_SIZE = 100000
class TeamsDashboardView(View):
......@@ -168,9 +171,12 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
* topic_id: Filters the result to teams associated with the given
topic.
* text_search: Currently not supported.
* text_search: Searches for full word matches on the name, description,
country, and language fields. NOTES: Search is on full names for countries
and languages, not the ISO codes. Text_search cannot be requested along with
order_by. Searching relies on the ENABLE_TEAMS_SEARCH flag being set to True.
* order_by: Must be one of the following:
* order_by: Cannot be called along with text_search. Must be one of the following:
* name: Orders results by case insensitive team name (default).
......@@ -313,6 +319,12 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
status=status.HTTP_400_BAD_REQUEST
)
if 'text_search' in request.QUERY_PARAMS and 'order_by' in request.QUERY_PARAMS:
return Response(
build_api_error(ugettext_noop("text_search and order_by cannot be provided together")),
status=status.HTTP_400_BAD_REQUEST
)
if 'topic_id' in request.QUERY_PARAMS:
topic_id = request.QUERY_PARAMS['topic_id']
if topic_id not in [topic['id'] for topic in course_module.teams_configuration['topics']]:
......@@ -324,14 +336,28 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
result_filter.update({'topic_id': request.QUERY_PARAMS['topic_id']})
if 'include_inactive' in request.QUERY_PARAMS and request.QUERY_PARAMS['include_inactive'].lower() == 'true':
del result_filter['is_active']
if 'text_search' in request.QUERY_PARAMS:
return Response(
build_api_error(ugettext_noop("text_search is not yet supported.")),
status=status.HTTP_400_BAD_REQUEST
if 'text_search' in request.QUERY_PARAMS and CourseTeamIndexer.search_is_enabled():
search_engine = CourseTeamIndexer.engine()
text_search = request.QUERY_PARAMS['text_search'].encode('utf-8')
result_filter.update({'course_id': course_id_string})
search_results = search_engine.search(
query_string=text_search,
field_dictionary=result_filter,
size=MAXIMUM_SEARCH_SIZE,
)
queryset = CourseTeam.objects.filter(**result_filter)
paginated_results = paginate_search_results(
CourseTeam,
search_results,
self.get_paginate_by(),
self.get_page()
)
serializer = self.get_pagination_serializer(paginated_results)
else:
queryset = CourseTeam.objects.filter(**result_filter)
order_by_input = request.QUERY_PARAMS.get('order_by', 'name')
if order_by_input == 'name':
queryset = queryset.extra(select={'lower_name': "lower(name)"})
......@@ -355,6 +381,7 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
page = self.paginate_queryset(queryset)
serializer = self.get_pagination_serializer(page)
serializer.context.update({'sort_order': order_by_input}) # pylint: disable=maybe-no-member
return Response(serializer.data) # pylint: disable=maybe-no-member
def post(self, request):
......@@ -408,6 +435,14 @@ class TeamsListView(ExpandableFieldViewMixin, GenericAPIView):
team.add_user(request.user)
return Response(CourseTeamSerializer(team).data)
def get_page(self):
    """
    Return the requested page: the URL kwarg if set, else the query
    parameter, else 1.
    """
    # Mirrors the lookup done inside GenericAPIView#paginate_queryset; it is
    # repeated here because paginate_search_results needs the page value
    # outside of that method.
    key = self.page_kwarg
    from_url = self.kwargs.get(key)
    from_query = self.request.QUERY_PARAMS.get(key)
    for candidate in (from_url, from_query):
        if candidate:
            return candidate
    return 1
class IsEnrolledOrIsStaff(permissions.BasePermission):
"""Permission that checks to see if the user is enrolled in the course or is staff."""
......
......@@ -621,7 +621,8 @@ PDF_RECEIPT_COBRAND_LOGO_HEIGHT_MM = ENV_TOKENS.get(
if FEATURES.get('ENABLE_COURSEWARE_SEARCH') or \
FEATURES.get('ENABLE_DASHBOARD_SEARCH') or \
FEATURES.get('ENABLE_COURSE_DISCOVERY'):
FEATURES.get('ENABLE_COURSE_DISCOVERY') or \
FEATURES.get('ENABLE_TEAMS_SEARCH'):
# Use ElasticSearch as the search engine herein
SEARCH_ENGINE = "search.elastic.ElasticSearchEngine"
......
......@@ -401,6 +401,9 @@ FEATURES = {
# Teams feature
'ENABLE_TEAMS': True,
# Enable indexing teams for search
'ENABLE_TEAMS_SEARCH': False,
# Show video bumper in LMS
'ENABLE_VIDEO_BUMPER': False,
......
......@@ -484,6 +484,9 @@ FEATURES['ENABLE_EDXNOTES'] = True
# Enable teams feature for tests.
FEATURES['ENABLE_TEAMS'] = True
# Enable indexing teams for search
FEATURES['ENABLE_TEAMS_SEARCH'] = True
# Add milestones to Installed apps for testing
INSTALLED_APPS += ('milestones', 'openedx.core.djangoapps.call_stack_manager')
......
""" Paginator methods for edX API implementations."""
from django.http import Http404
from django.utils.translation import ugettext as _
from django.core.paginator import Paginator, InvalidPage
def paginate_search_results(object_class, search_results, page_size, page):
    """
    Takes edx-search results and returns a Page object populated
    with db objects for that page.

    :param object_class: Model class to use when querying the db for objects.
    :param search_results: edX-search results; the hits are read from its
        'results' list and each hit's primary key from item['data']['pk'].
    :param page_size: Number of results per page.
    :param page: Page number, or the string 'last'.

    :return: Page object whose object_list holds the model objects, in the
        same order as the search hits. Hits with no matching db row are
        left out.
    :raises Http404: if the page is not a valid page number.
    """
    paginator = Paginator(search_results['results'], page_size)

    # This page-resolution logic is taken from GenericAPIView#paginate_queryset.
    # It is repeated here because the paginator is built from search hits
    # rather than from a queryset.
    try:
        page_number = paginator.validate_number(page)
    except InvalidPage:
        if page == 'last':
            page_number = paginator.num_pages
        else:
            raise Http404(_("Page is not 'last', nor can it be converted to an int."))

    try:
        paged_results = paginator.page(page_number)
    except InvalidPage as e:  # pylint: disable=invalid-name
        raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
            'page_number': page_number,
            'message': str(e)
        })

    search_queryset_pks = [item['data']['pk'] for item in paged_results.object_list]
    queryset = object_class.objects.filter(pk__in=search_queryset_pks)

    # Index the db objects by primary key once, then walk the search hits so
    # the result keeps the search ordering. A hit whose row no longer exists
    # is skipped instead of putting None into the page.
    objects_by_pk = {obj.pk: obj for obj in queryset}
    paged_results.object_list = [
        objects_by_pk[primary_key]
        for primary_key in search_queryset_pks
        if primary_key in objects_by_pk
    ]

    return paged_results
""" Tests paginator methods """
import ddt
from mock import Mock, MagicMock
from unittest import TestCase
from django.http import Http404
from openedx.core.lib.api.paginators import paginate_search_results
@ddt.ddt
class PaginateSearchResultsTestCase(TestCase):
    """Test cases for paginate_search_results method"""

    def setUp(self):
        super(PaginateSearchResultsTestCase, self).setUp()

        self.default_size = 6
        self.default_page = 1
        # Six search hits with pks 0..5; "count" matches the number of hits.
        self.search_results = {
            "count": 6,
            "took": 1,
            "results": [
                {
                    '_id': obj_id,
                    'data': {
                        'pk': obj_id,
                        'name': 'object {}'.format(obj_id)
                    }
                }
                for obj_id in range(6)
            ]
        }
        self.mock_model = Mock()
        self.mock_model.objects = Mock()
        self.mock_model.objects.filter = Mock()

    @ddt.data(
        (1, 1, True),
        (1, 3, True),
        (1, 5, True),
        (1, 10, False),
        (2, 1, True),
        (2, 3, False),
        (2, 5, False),
    )
    @ddt.unpack
    def test_paginated_results(self, page_number, page_size, has_next):
        """ Test the page returned has the expected db objects and acts
        like a proper page object.
        """
        id_range = get_object_range(page_number, page_size)
        db_objects = [build_mock_object(obj_id) for obj_id in id_range]
        self.mock_model.objects.filter = MagicMock(return_value=db_objects)

        page = paginate_search_results(self.mock_model, self.search_results, page_size, page_number)

        self.mock_model.objects.filter.assert_called_with(pk__in=id_range)
        self.assertEqual(db_objects, page.object_list)
        # assertEqual, not assertTrue: assertTrue would treat the second
        # argument as the failure message and never compare the two values.
        self.assertEqual(page.number, page_number)
        self.assertEqual(page.has_next(), has_next)

    def test_paginated_results_last_keyword(self):
        """ Test the page returned has the expected db objects and acts
        like a proper page object using 'last' keyword.
        """
        page_number = 2
        page_size = 3
        id_range = get_object_range(page_number, page_size)
        db_objects = [build_mock_object(obj_id) for obj_id in id_range]
        self.mock_model.objects.filter = MagicMock(return_value=db_objects)

        # Paginate with the same page_size the expected id range was built
        # from, so that 'last' resolves to page 2 of 2.
        page = paginate_search_results(self.mock_model, self.search_results, page_size, 'last')

        self.mock_model.objects.filter.assert_called_with(pk__in=id_range)
        self.assertEqual(db_objects, page.object_list)
        self.assertEqual(page.number, page_number)
        self.assertFalse(page.has_next())

    @ddt.data(10, -1, 0, 'str')
    def test_invalid_page_number(self, page_num):
        """ Test that a Http404 error is raised with non-integer and out-of-range pages
        """
        with self.assertRaises(Http404):
            paginate_search_results(self.mock_model, self.search_results, self.default_size, page_num)
def build_mock_object(obj_id):
    """ Return a Mock whose pk is obj_id and whose name is 'object <obj_id>' """
    fake_row = Mock()
    # configure_mock is used so that `name` becomes a plain attribute.
    fake_row.configure_mock(pk=obj_id, name="object {}".format(obj_id))
    return fake_row
def get_object_range(page, page_size, max_id=5):
    """ Get the list of expected object ids given a page and page size.

    The ids are clamped to the sample data: the start never goes past
    max_id and the end never goes past max_id + 1. max_id defaults to 5,
    the highest id in the sample data.
    """
    start = min((page - 1) * page_size, max_id)
    end = min(start + page_size, max_id + 1)
    # Return a real list so equality checks against list arguments hold on
    # both Python 2 and Python 3.
    return list(range(start, end))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment