Commit de68c304 by Greg Price

Merge pull request #8530 from edx/gprice/discussion-api-text-search

Add text_search parameter to discussion API
parents fa4dcdeb fecbdcd0
...@@ -223,7 +223,7 @@ def get_course_topics(request, course_key): ...@@ -223,7 +223,7 @@ def get_course_topics(request, course_key):
} }
def get_thread_list(request, course_key, page, page_size, topic_id_list=None): def get_thread_list(request, course_key, page, page_size, topic_id_list=None, text_search=None):
""" """
Return the list of all discussion threads pertaining to the given course Return the list of all discussion threads pertaining to the given course
...@@ -234,16 +234,30 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): ...@@ -234,16 +234,30 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None):
page: The page number (1-indexed) to retrieve page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page page_size: The number of threads to retrieve per page
topic_id_list: The list of topic_ids to get the discussion threads for topic_id_list: The list of topic_ids to get the discussion threads for
text_search A text search query string to match
Note that topic_id_list and text_search are mutually exclusive.
Returns: Returns:
A paginated result containing a list of threads; see A paginated result containing a list of threads; see
discussion_api.views.ThreadViewSet for more detail. discussion_api.views.ThreadViewSet for more detail.
Raises:
ValueError: if more than one of the mutually exclusive parameters is
provided
Http404: if the requesting user does not have access to the requested course
or a page beyond the last is requested
""" """
exclusive_param_count = sum(1 for param in [topic_id_list, text_search] if param)
if exclusive_param_count > 1: # pragma: no cover
raise ValueError("More than one mutually exclusive param passed to get_thread_list")
course = _get_course_or_404(course_key, request.user) course = _get_course_or_404(course_key, request.user)
context = get_context(course, request) context = get_context(course, request)
topic_ids_csv = ",".join(topic_id_list) if topic_id_list else None topic_ids_csv = ",".join(topic_id_list) if topic_id_list else None
threads, result_page, num_pages, _ = Thread.search({ threads, result_page, num_pages, text_search_rewrite = Thread.search({
"course_id": unicode(course.id), "course_id": unicode(course.id),
"group_id": ( "group_id": (
None if context["is_requester_privileged"] else None if context["is_requester_privileged"] else
...@@ -254,6 +268,7 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): ...@@ -254,6 +268,7 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None):
"page": page, "page": page,
"per_page": page_size, "per_page": page_size,
"commentable_ids": topic_ids_csv, "commentable_ids": topic_ids_csv,
"text": text_search,
}) })
# The comments service returns the last page of results if the requested # The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general # page is beyond the last page, but we want be consistent with DRF's general
...@@ -262,7 +277,9 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None): ...@@ -262,7 +277,9 @@ def get_thread_list(request, course_key, page, page_size, topic_id_list=None):
raise Http404 raise Http404
results = [ThreadSerializer(thread, context=context).data for thread in threads] results = [ThreadSerializer(thread, context=context).data for thread in threads]
return get_paginated_data(request, results, page, num_pages) ret = get_paginated_data(request, results, page, num_pages)
ret["text_search_rewrite"] = text_search_rewrite
return ret
def get_comment_list(request, thread_id, endorsed, page, page_size): def get_comment_list(request, thread_id, endorsed, page, page_size):
......
...@@ -45,8 +45,11 @@ class ThreadListGetForm(_PaginationForm): ...@@ -45,8 +45,11 @@ class ThreadListGetForm(_PaginationForm):
""" """
A form to validate query parameters in the thread list retrieval endpoint A form to validate query parameters in the thread list retrieval endpoint
""" """
EXCLUSIVE_PARAMS = ["topic_id", "text_search"]
course_id = CharField() course_id = CharField()
topic_id = TopicIdField(required=False) topic_id = TopicIdField(required=False)
text_search = CharField(required=False)
def clean_course_id(self): def clean_course_id(self):
"""Validate course_id""" """Validate course_id"""
...@@ -56,6 +59,19 @@ class ThreadListGetForm(_PaginationForm): ...@@ -56,6 +59,19 @@ class ThreadListGetForm(_PaginationForm):
except InvalidKeyError: except InvalidKeyError:
raise ValidationError("'{}' is not a valid course id".format(value)) raise ValidationError("'{}' is not a valid course id".format(value))
def clean(self):
cleaned_data = super(ThreadListGetForm, self).clean()
exclusive_params_count = sum(
1 for param in self.EXCLUSIVE_PARAMS if cleaned_data.get(param)
)
if exclusive_params_count > 1:
raise ValidationError(
"The following query parameters are mutually exclusive: {}".format(
", ".join(self.EXCLUSIVE_PARAMS)
)
)
return cleaned_data
class ThreadActionsForm(Form): class ThreadActionsForm(Form):
""" """
......
...@@ -471,7 +471,15 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -471,7 +471,15 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
self.author = UserFactory.create() self.author = UserFactory.create()
self.cohort = CohortFactory.create(course_id=self.course.id) self.cohort = CohortFactory.create(course_id=self.course.id)
def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None, topic_id_list=None): def get_thread_list(
self,
threads,
page=1,
page_size=1,
num_pages=1,
course=None,
topic_id_list=None,
):
""" """
Register the appropriate comments service response, then call Register the appropriate comments service response, then call
get_thread_list and return the result. get_thread_list and return the result.
...@@ -502,6 +510,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -502,6 +510,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
"results": [], "results": [],
"next": None, "next": None,
"previous": None, "previous": None,
"text_search_rewrite": None,
} }
) )
...@@ -636,6 +645,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -636,6 +645,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
"results": expected_threads, "results": expected_threads,
"next": None, "next": None,
"previous": None, "previous": None,
"text_search_rewrite": None,
} }
) )
...@@ -670,6 +680,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -670,6 +680,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
"results": [], "results": [],
"next": "http://testserver/test_path?page=2", "next": "http://testserver/test_path?page=2",
"previous": None, "previous": None,
"text_search_rewrite": None,
} }
) )
self.assertEqual( self.assertEqual(
...@@ -678,6 +689,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -678,6 +689,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
"results": [], "results": [],
"next": "http://testserver/test_path?page=3", "next": "http://testserver/test_path?page=3",
"previous": "http://testserver/test_path?page=1", "previous": "http://testserver/test_path?page=1",
"text_search_rewrite": None,
} }
) )
self.assertEqual( self.assertEqual(
...@@ -686,6 +698,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -686,6 +698,7 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
"results": [], "results": [],
"next": None, "next": None,
"previous": "http://testserver/test_path?page=2", "previous": "http://testserver/test_path?page=2",
"text_search_rewrite": None,
} }
) )
...@@ -694,6 +707,34 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -694,6 +707,34 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_thread_list(self.request, self.course.id, page=4, page_size=10) get_thread_list(self.request, self.course.id, page=4, page_size=10)
@ddt.data(None, "rewritten search string")
def test_text_search(self, text_search_rewrite):
self.register_get_threads_search_response([], text_search_rewrite)
self.assertEqual(
get_thread_list(
self.request,
self.course.id,
page=1,
page_size=10,
text_search="test search string"
),
{
"results": [],
"next": None,
"previous": None,
"text_search_rewrite": text_search_rewrite,
}
)
self.assert_last_query_params({
"course_id": [unicode(self.course.id)],
"sort_key": ["date"],
"sort_order": ["desc"],
"page": ["1"],
"per_page": ["10"],
"recursive": ["False"],
"text": ["test search string"],
})
@ddt.ddt @ddt.ddt
class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
......
""" """
Tests for Discussion API forms Tests for Discussion API forms
""" """
import itertools
from unittest import TestCase from unittest import TestCase
from urllib import urlencode from urllib import urlencode
import ddt
from django.http import QueryDict from django.http import QueryDict
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
...@@ -63,6 +66,7 @@ class PaginationTestMixin(object): ...@@ -63,6 +66,7 @@ class PaginationTestMixin(object):
self.assert_field_value("page_size", 100) self.assert_field_value("page_size", 100)
@ddt.ddt
class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for ThreadListGetForm""" """Tests for ThreadListGetForm"""
FORM_CLASS = ThreadListGetForm FORM_CLASS = ThreadListGetForm
...@@ -81,7 +85,6 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -81,7 +85,6 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
) )
def test_basic(self): def test_basic(self):
self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"])
form = self.get_form(expected_valid=True) form = self.get_form(expected_valid=True)
self.assertEqual( self.assertEqual(
form.cleaned_data, form.cleaned_data,
...@@ -89,10 +92,27 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -89,10 +92,27 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"course_id": CourseLocator.from_string("Foo/Bar/Baz"), "course_id": CourseLocator.from_string("Foo/Bar/Baz"),
"page": 2, "page": 2,
"page_size": 13, "page_size": 13,
"topic_id": ["example topic_id", "example 2nd topic_id"], "topic_id": [],
"text_search": "",
} }
) )
def test_topic_id(self):
self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"])
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["topic_id"],
["example topic_id", "example 2nd topic_id"],
)
def test_text_search(self):
self.form_data["text_search"] = "test search string"
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["text_search"],
"test search string",
)
def test_missing_course_id(self): def test_missing_course_id(self):
self.form_data.pop("course_id") self.form_data.pop("course_id")
self.assert_error("course_id", "This field is required.") self.assert_error("course_id", "This field is required.")
...@@ -105,6 +125,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -105,6 +125,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data.setlist("topic_id", ["", "not empty"]) self.form_data.setlist("topic_id", ["", "not empty"])
self.assert_error("topic_id", "This field cannot be empty.") self.assert_error("topic_id", "This field cannot be empty.")
@ddt.data(*itertools.combinations(["topic_id", "text_search"], 2))
def test_mutually_exclusive(self, params):
self.form_data.update({param: "dummy" for param in params})
self.assert_error(
"__all__",
"The following query parameters are mutually exclusive: topic_id, text_search"
)
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for CommentListGetForm""" """Tests for CommentListGetForm"""
......
...@@ -212,6 +212,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -212,6 +212,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"results": expected_threads, "results": expected_threads,
"next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2", "next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2",
"previous": None, "previous": None,
"text_search_rewrite": None,
} }
) )
self.assert_last_query_params({ self.assert_last_query_params({
...@@ -244,6 +245,28 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -244,6 +245,28 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"recursive": ["False"], "recursive": ["False"],
}) })
def test_text_search(self):
self.register_get_user_response(self.user)
self.register_get_threads_search_response([], None)
response = self.client.get(
self.url,
{"course_id": unicode(self.course.id), "text_search": "test search string"}
)
self.assert_response_correct(
response,
200,
{"results": [], "next": None, "previous": None, "text_search_rewrite": None}
)
self.assert_last_query_params({
"course_id": [unicode(self.course.id)],
"sort_key": ["date"],
"sort_order": ["desc"],
"page": ["1"],
"per_page": ["10"],
"recursive": ["False"],
"text": ["test search string"],
})
@httpretty.activate @httpretty.activate
class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
......
...@@ -71,6 +71,20 @@ class CommentsServiceMockMixin(object): ...@@ -71,6 +71,20 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_get_threads_search_response(self, threads, rewrite):
"""Register a mock response for GET on the CS thread search endpoint"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/search/threads",
body=json.dumps({
"collection": threads,
"page": 1,
"num_pages": 1,
"corrected_text": rewrite,
}),
status=200
)
def register_post_thread_response(self, thread_data): def register_post_thread_response(self, thread_data):
"""Register a mock response for POST on the CS commentable endpoint""" """Register a mock response for POST on the CS commentable endpoint"""
httpretty.register_uri( httpretty.register_uri(
......
...@@ -137,6 +137,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -137,6 +137,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
multiple topic_id queries to retrieve threads from multiple topics multiple topic_id queries to retrieve threads from multiple topics
at once. at once.
* text_search: A search string to match. Any thread whose content
(including the bodies of comments in the thread) matches the search
string will be returned.
**POST Parameters**: **POST Parameters**:
* course_id (required): The course to create the thread in * course_id (required): The course to create the thread in
...@@ -166,6 +170,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -166,6 +170,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* previous: The URL of the previous page (or null if last page) * previous: The URL of the previous page (or null if last page)
* text_search_rewrite: The search string to which the text_search
parameter was rewritten in order to match threads (e.g. for spelling
correction)
**POST/PATCH response values**: **POST/PATCH response values**:
* id: The id of the thread * id: The id of the thread
...@@ -217,6 +225,7 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -217,6 +225,7 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["page"], form.cleaned_data["page"],
form.cleaned_data["page_size"], form.cleaned_data["page_size"],
form.cleaned_data["topic_id"], form.cleaned_data["topic_id"],
form.cleaned_data["text_search"],
) )
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment