Commit f19af520 by wajeeha-khalid

MA-2050: add optional user profile details in GET thread & comment

parent c681c567
...@@ -51,6 +51,7 @@ class ThreadListGetForm(_PaginationForm): ...@@ -51,6 +51,7 @@ class ThreadListGetForm(_PaginationForm):
choices=[(choice, choice) for choice in ["asc", "desc"]], choices=[(choice, choice) for choice in ["asc", "desc"]],
required=False required=False
) )
requested_fields = MultiValueField(required=False)
def clean_order_by(self): def clean_order_by(self):
"""Return a default choice""" """Return a default choice"""
...@@ -106,6 +107,7 @@ class CommentListGetForm(_PaginationForm): ...@@ -106,6 +107,7 @@ class CommentListGetForm(_PaginationForm):
""" """
thread_id = CharField() thread_id = CharField()
endorsed = ExtendedNullBooleanField(required=False) endorsed = ExtendedNullBooleanField(required=False)
requested_fields = MultiValueField(required=False)
class CommentActionsForm(Form): class CommentActionsForm(Form):
...@@ -115,3 +117,10 @@ class CommentActionsForm(Form): ...@@ -115,3 +117,10 @@ class CommentActionsForm(Form):
""" """
voted = BooleanField(required=False) voted = BooleanField(required=False)
abuse_flagged = BooleanField(required=False) abuse_flagged = BooleanField(required=False)
class CommentGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment retrieval endpoint
"""
requested_fields = MultiValueField(required=False)
...@@ -40,7 +40,8 @@ from discussion_api.tests.utils import ( ...@@ -40,7 +40,8 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin, CommentsServiceMockMixin,
make_minimal_cs_comment, make_minimal_cs_comment,
make_minimal_cs_thread, make_minimal_cs_thread,
make_paginated_api_response make_paginated_api_response,
ProfileImageTestMixin,
) )
from django_comment_common.models import ( from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_ADMINISTRATOR,
...@@ -1148,7 +1149,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1148,7 +1149,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_basic_query_params(self): def test_basic_query_params(self):
self.get_comment_list( self.get_comment_list(
self.make_minimal_cs_thread({ self.make_minimal_cs_thread({
"children": [make_minimal_cs_comment()], "children": [make_minimal_cs_comment({"username": self.user.username})],
"resp_total": 71 "resp_total": 71
}), }),
page=6, page=6,
...@@ -1254,8 +1255,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1254,8 +1255,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_content(self): def test_question_content(self):
thread = self.make_minimal_cs_thread({ thread = self.make_minimal_cs_thread({
"thread_type": "question", "thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})], "endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment", "username": self.user.username})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})], "non_endorsed_responses": [make_minimal_cs_comment({
"id": "non_endorsed_comment", "username": self.user.username
})],
"non_endorsed_resp_total": 1, "non_endorsed_resp_total": 1,
}) })
...@@ -1274,7 +1277,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1274,7 +1277,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
"anonymous": True, "anonymous": True,
"children": [ "children": [
make_minimal_cs_comment({ make_minimal_cs_comment({
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"} "username": self.user.username,
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"},
}) })
] ]
}) })
...@@ -1301,7 +1305,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1301,7 +1305,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
# number of responses is unrealistic but convenient for this test # number of responses is unrealistic but convenient for this test
thread = self.make_minimal_cs_thread({ thread = self.make_minimal_cs_thread({
"thread_type": thread_type, "thread_type": thread_type,
response_field: [make_minimal_cs_comment()], response_field: [make_minimal_cs_comment({"username": self.user.username})],
response_total_field: 5, response_total_field: 5,
}) })
...@@ -1337,9 +1341,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1337,9 +1341,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_endorsed_pagination(self): def test_question_endorsed_pagination(self):
thread = self.make_minimal_cs_thread({ thread = self.make_minimal_cs_thread({
"thread_type": "question", "thread_type": "question",
"endorsed_responses": [ "endorsed_responses": [make_minimal_cs_comment({
make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10) "id": "comment_{}".format(i),
] "username": self.user.username
}) for i in range(10)]
}) })
def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev): def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev):
...@@ -1553,7 +1558,7 @@ class CreateThreadTest( ...@@ -1553,7 +1558,7 @@ class CreateThreadTest(
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user]) cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id) role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user] role.users = [self.user]
self.register_post_thread_response({}) self.register_post_thread_response({"username": self.user.username})
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["course_id"] = unicode(cohort_course.id) data["course_id"] = unicode(cohort_course.id)
if data_group_state == "group_is_none": if data_group_state == "group_is_none":
...@@ -1582,7 +1587,7 @@ class CreateThreadTest( ...@@ -1582,7 +1587,7 @@ class CreateThreadTest(
self.fail("Unexpected validation error: {}".format(ex)) self.fail("Unexpected validation error: {}".format(ex))
def test_following(self): def test_following(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_subscription_response(self.user) self.register_subscription_response(self.user)
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["following"] = "True" data["following"] = "True"
...@@ -1600,7 +1605,7 @@ class CreateThreadTest( ...@@ -1600,7 +1605,7 @@ class CreateThreadTest(
) )
def test_voted(self): def test_voted(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_votes_response("test_id") self.register_thread_votes_response("test_id")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["voted"] = "True" data["voted"] = "True"
...@@ -1616,7 +1621,7 @@ class CreateThreadTest( ...@@ -1616,7 +1621,7 @@ class CreateThreadTest(
) )
def test_abuse_flagged(self): def test_abuse_flagged(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_flag_response("test_id") self.register_thread_flag_response("test_id")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["abuse_flagged"] = "True" data["abuse_flagged"] = "True"
...@@ -1802,7 +1807,7 @@ class CreateCommentTest( ...@@ -1802,7 +1807,7 @@ class CreateCommentTest(
"user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1), "user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1),
}) })
) )
self.register_post_comment_response({}, "test_thread") self.register_post_comment_response({"username": self.user.username}, "test_thread")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["endorsed"] = True data["endorsed"] = True
expected_error = ( expected_error = (
...@@ -1817,7 +1822,7 @@ class CreateCommentTest( ...@@ -1817,7 +1822,7 @@ class CreateCommentTest(
self.assertTrue(expected_error) self.assertTrue(expected_error)
def test_voted(self): def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread") self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_votes_response("test_comment") self.register_comment_votes_response("test_comment")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["voted"] = "True" data["voted"] = "True"
...@@ -1833,7 +1838,7 @@ class CreateCommentTest( ...@@ -1833,7 +1838,7 @@ class CreateCommentTest(
) )
def test_abuse_flagged(self): def test_abuse_flagged(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread") self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_flag_response("test_comment") self.register_comment_flag_response("test_comment")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["abuse_flagged"] = "True" data["abuse_flagged"] = "True"
...@@ -1906,7 +1911,7 @@ class CreateCommentTest( ...@@ -1906,7 +1911,7 @@ class CreateCommentTest(
cohort.id + 1 cohort.id + 1
), ),
})) }))
self.register_post_comment_response({}, thread_id="cohort_thread") self.register_post_comment_response({"username": self.user.username}, thread_id="cohort_thread")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["thread_id"] = "cohort_thread" data["thread_id"] = "cohort_thread"
expected_error = ( expected_error = (
......
...@@ -69,6 +69,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -69,6 +69,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"view": "", "view": "",
"order_by": "last_activity_at", "order_by": "last_activity_at",
"order_direction": "desc", "order_direction": "desc",
"requested_fields": set(),
} }
) )
...@@ -154,6 +155,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -154,6 +155,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data[field] = value self.form_data[field] = value
self.assert_field_value(field, value) self.assert_field_value(field, value)
def test_requested_fields(self):
self.form_data["requested_fields"] = "profile_image"
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
@ddt.ddt @ddt.ddt
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
...@@ -177,7 +186,8 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -177,7 +186,8 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"thread_id": "deadbeef", "thread_id": "deadbeef",
"endorsed": False, "endorsed": False,
"page": 2, "page": 2,
"page_size": 13 "page_size": 13,
"requested_fields": set(),
} }
) )
...@@ -202,3 +212,11 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -202,3 +212,11 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
def test_invalid_endorsed(self): def test_invalid_endorsed(self):
self.form_data["endorsed"] = "invalid-boolean" self.form_data["endorsed"] = "invalid-boolean"
self.assert_error("endorsed", "Invalid Boolean Value.") self.assert_error("endorsed", "Invalid Boolean Value.")
def test_requested_fields(self):
self.form_data["requested_fields"] = {"profile_image"}
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
...@@ -16,6 +16,7 @@ from discussion_api.tests.utils import ( ...@@ -16,6 +16,7 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin, CommentsServiceMockMixin,
make_minimal_cs_thread, make_minimal_cs_thread,
make_minimal_cs_comment, make_minimal_cs_comment,
ProfileImageTestMixin,
) )
from django_comment_common.models import ( from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_ADMINISTRATOR,
...@@ -237,7 +238,7 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe ...@@ -237,7 +238,7 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
thread_data = self.make_cs_content({}) thread_data = self.make_cs_content({})
del thread_data["pinned"] del thread_data["pinned"]
self.register_get_thread_response(thread_data) self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"])) serialized = self.serialize(thread_data)
self.assertEqual(serialized["pinned"], False) self.assertEqual(serialized["pinned"], False)
def test_group(self): def test_group(self):
...@@ -255,14 +256,14 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe ...@@ -255,14 +256,14 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
def test_response_count(self): def test_response_count(self):
thread_data = self.make_cs_content({"resp_total": 2}) thread_data = self.make_cs_content({"resp_total": 2})
self.register_get_thread_response(thread_data) self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"])) serialized = self.serialize(thread_data)
self.assertEqual(serialized["response_count"], 2) self.assertEqual(serialized["response_count"], 2)
def test_response_count_missing(self): def test_response_count_missing(self):
thread_data = self.make_cs_content({}) thread_data = self.make_cs_content({})
del thread_data["resp_total"] del thread_data["resp_total"]
self.register_get_thread_response(thread_data) self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"])) serialized = self.serialize(thread_data)
self.assertNotIn("response_count", serialized) self.assertNotIn("response_count", serialized)
...@@ -459,6 +460,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -459,6 +460,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
"title": "Original Title", "title": "Original Title",
"body": "Original body", "body": "Original body",
"user_id": str(self.user.id), "user_id": str(self.user.id),
"username": self.user.username,
"read": "False", "read": "False",
"endorsed": "False" "endorsed": "False"
})) }))
...@@ -480,7 +482,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -480,7 +482,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
return serializer.data return serializer.data
def test_create_minimal(self): def test_create_minimal(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
saved = self.save_and_reserialize(self.minimal_data) saved = self.save_and_reserialize(self.minimal_data)
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
...@@ -500,7 +502,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -500,7 +502,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
self.assertEqual(saved["id"], "test_id") self.assertEqual(saved["id"], "test_id")
def test_create_all_fields(self): def test_create_all_fields(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["group_id"] = 42 data["group_id"] = 42
self.save_and_reserialize(data) self.save_and_reserialize(data)
...@@ -540,7 +542,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -540,7 +542,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
) )
def test_create_type(self): def test_create_type(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["type"] = "question" data["type"] = "question"
self.save_and_reserialize(data) self.save_and_reserialize(data)
...@@ -655,6 +657,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -655,6 +657,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
"thread_id": "existing_thread", "thread_id": "existing_thread",
"body": "Original body", "body": "Original body",
"user_id": str(self.user.id), "user_id": str(self.user.id),
"username": self.user.username,
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
})) }))
...@@ -685,7 +688,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -685,7 +688,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["parent_id"] = parent_id data["parent_id"] = parent_id
self.register_get_comment_response({"thread_id": "test_thread", "id": parent_id}) self.register_get_comment_response({"thread_id": "test_thread", "id": parent_id})
self.register_post_comment_response( self.register_post_comment_response(
{"id": "test_comment"}, {"id": "test_comment", "username": self.user.username},
thread_id="test_thread", thread_id="test_thread",
parent_id=parent_id parent_id=parent_id
) )
...@@ -712,7 +715,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -712,7 +715,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["endorsed"] = True data["endorsed"] = True
self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"}) self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"})
self.register_post_comment_response( self.register_post_comment_response(
{"id": "test_comment"}, {"id": "test_comment", "username": self.user.username},
thread_id="test_thread", thread_id="test_thread",
parent_id="test_parent" parent_id="test_parent"
) )
...@@ -807,7 +810,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -807,7 +810,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
def test_create_endorsed(self): def test_create_endorsed(self):
# TODO: The comments service doesn't populate the endorsement field on # TODO: The comments service doesn't populate the endorsement field on
# comment creation, so this is sadly realistic # comment creation, so this is sadly realistic
self.register_post_comment_response({}, thread_id="test_thread") self.register_post_comment_response({"username": self.user.username}, thread_id="test_thread")
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["endorsed"] = True data["endorsed"] = True
saved = self.save_and_reserialize(data) saved = self.save_and_reserialize(data)
......
""" """
Discussion API test utilities Discussion API test utilities
""" """
from contextlib import closing
from datetime import datetime
import json import json
import re import re
import hashlib
import httpretty import httpretty
from pytz import UTC
from PIL import Image
from openedx.core.djangoapps.profile_images.images import create_profile_images
from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file
from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_names, set_has_profile_image
def _get_thread_callback(thread_data): def _get_thread_callback(thread_data):
...@@ -392,3 +401,56 @@ def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=No ...@@ -392,3 +401,56 @@ def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=No
}, },
"results": results or [] "results": results or []
} }
class ProfileImageTestMixin(object):
"""
Mixin with utility methods for user profile image
"""
TEST_PROFILE_IMAGE_UPLOADED_AT = datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
def create_profile_image(self, user, storage):
"""
Creates profile image for user and checks that created image exists in storage
"""
with make_image_file() as image_file:
create_profile_images(image_file, get_profile_image_names(user.username))
self.check_images(user, storage)
set_has_profile_image(user.username, True, self.TEST_PROFILE_IMAGE_UPLOADED_AT)
def check_images(self, user, storage, exist=True):
"""
If exist is True, make sure the images physically exist in storage
with correct sizes and formats.
If exist is False, make sure none of the images exist.
"""
for size, name in get_profile_image_names(user.username).items():
if exist:
self.assertTrue(storage.exists(name))
with closing(Image.open(storage.path(name))) as img:
self.assertEqual(img.size, (size, size))
self.assertEqual(img.format, 'JPEG')
else:
self.assertFalse(storage.exists(name))
def get_expected_user_profile(self, username):
"""
Returns the expected user profile data for a given username
"""
url = 'http://example-storage.com/profile-images/{filename}_{{size}}.jpg?v={timestamp}'.format(
filename=hashlib.md5('secret' + username).hexdigest(),
timestamp=self.TEST_PROFILE_IMAGE_UPLOADED_AT.strftime("%s")
)
return {
'profile': {
'image': {
'has_image': True,
'image_url_full': url.format(size=500),
'image_url_large': url.format(size=120),
'image_url_medium': url.format(size=50),
'image_url_small': url.format(size=30),
}
}
}
...@@ -26,7 +26,7 @@ from discussion_api.api import ( ...@@ -26,7 +26,7 @@ from discussion_api.api import (
update_comment, update_comment,
update_thread, update_thread,
) )
from discussion_api.forms import CommentListGetForm, ThreadListGetForm, _PaginationForm from discussion_api.forms import CommentListGetForm, ThreadListGetForm, CommentGetForm
from openedx.core.lib.api.parsers import MergePatchParser from openedx.core.lib.api.parsers import MergePatchParser
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, view_auth_classes from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, view_auth_classes
...@@ -118,7 +118,7 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -118,7 +118,7 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015 GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015
GET /api/discussion/v1/threads/thread_id GET /api/discussion/v1/threads/{thread_id}
POST /api/discussion/v1/threads POST /api/discussion/v1/threads
{ {
...@@ -164,10 +164,19 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -164,10 +164,19 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
* view: "unread" for threads the requesting user has not read, or * view: "unread" for threads the requesting user has not read, or
"unanswered" for question threads with no marked answer. Only one "unanswered" for question threads with no marked answer. Only one
can be selected. can be selected.
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
The topic_id, text_search, and following parameters are mutually The topic_id, text_search, and following parameters are mutually
exclusive (i.e. only one may be specified in a request) exclusive (i.e. only one may be specified in a request)
**GET Thread Parameters**:
* thread_id (required): The id of the thread
* requested_fields (optional parameter): (list) Indicates which additional
fields to return for each thread. (supports 'profile_image')
**POST Parameters**: **POST Parameters**:
* course_id (required): The course to create the thread in * course_id (required): The course to create the thread in
...@@ -279,13 +288,15 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -279,13 +288,15 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["view"], form.cleaned_data["view"],
form.cleaned_data["order_by"], form.cleaned_data["order_by"],
form.cleaned_data["order_direction"], form.cleaned_data["order_direction"],
form.cleaned_data["requested_fields"]
) )
def retrieve(self, request, thread_id=None): def retrieve(self, request, thread_id=None):
""" """
Implements the GET method for thread ID Implements the GET method for thread ID
""" """
return Response(get_thread(request, thread_id)) requested_fields = request.GET.get('requested_fields')
return Response(get_thread(request, thread_id, requested_fields))
def create(self, request): def create(self, request):
""" """
...@@ -351,6 +362,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -351,6 +362,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100) * page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**GET Child Comment List Parameters**: **GET Child Comment List Parameters**:
* comment_id (required): The comment to retrieve child comments for * comment_id (required): The comment to retrieve child comments for
...@@ -359,6 +373,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -359,6 +373,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100) * page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**POST Parameters**: **POST Parameters**:
...@@ -453,21 +470,23 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -453,21 +470,23 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["thread_id"], form.cleaned_data["thread_id"],
form.cleaned_data["endorsed"], form.cleaned_data["endorsed"],
form.cleaned_data["page"], form.cleaned_data["page"],
form.cleaned_data["page_size"] form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
) )
def retrieve(self, request, comment_id=None): def retrieve(self, request, comment_id=None):
""" """
Implements the GET method for comments against response ID Implements the GET method for comments against response ID
""" """
form = _PaginationForm(request.GET) form = CommentGetForm(request.GET)
if not form.is_valid(): if not form.is_valid():
raise ValidationError(form.errors) raise ValidationError(form.errors)
return get_response_comments( return get_response_comments(
request, request,
comment_id, comment_id,
form.cleaned_data["page"], form.cleaned_data["page"],
form.cleaned_data["page_size"] form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
) )
def create(self, request): def create(self, request):
......
...@@ -66,7 +66,7 @@ def learner_profile_context(request, profile_username, user_is_staff): ...@@ -66,7 +66,7 @@ def learner_profile_context(request, profile_username, user_is_staff):
own_profile = (logged_in_user.username == profile_username) own_profile = (logged_in_user.username == profile_username)
account_settings_data = get_account_settings(request, profile_username) account_settings_data = get_account_settings(request, [profile_username])[0]
preferences_data = get_user_preferences(profile_user, profile_username) preferences_data = get_user_preferences(profile_user, profile_username)
......
...@@ -1627,7 +1627,7 @@ class TestSubmitPhotosForVerification(TestCase): ...@@ -1627,7 +1627,7 @@ class TestSubmitPhotosForVerification(TestCase):
""" """
request = RequestFactory().get('/url') request = RequestFactory().get('/url')
request.user = self.user request.user = self.user
account_settings = get_account_settings(request) account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings['name'], full_name) self.assertEqual(account_settings['name'], full_name)
def _get_post_data(self): def _get_post_data(self):
......
...@@ -41,7 +41,7 @@ visible_fields = _visible_fields ...@@ -41,7 +41,7 @@ visible_fields = _visible_fields
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError]) @intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
def get_account_settings(request, username=None, configuration=None, view=None): def get_account_settings(request, usernames=None, configuration=None, view=None):
"""Returns account information for a user serialized as JSON. """Returns account information for a user serialized as JSON.
Note: Note:
...@@ -52,8 +52,8 @@ def get_account_settings(request, username=None, configuration=None, view=None): ...@@ -52,8 +52,8 @@ def get_account_settings(request, username=None, configuration=None, view=None):
request (Request): The request object with account information about the requesting user. request (Request): The request object with account information about the requesting user.
Only the user with username `username` or users with "is_staff" privileges can get full Only the user with username `username` or users with "is_staff" privileges can get full
account information. Other users will get the account fields that the user has elected to share. account information. Other users will get the account fields that the user has elected to share.
username (str): Optional username for the desired account information. If not specified, usernames (list): Optional list of usernames for the desired account information. If not
`request.user.username` is assumed. specified, `request.user.username` is assumed.
configuration (dict): an optional configuration specifying which fields in the account configuration (dict): an optional configuration specifying which fields in the account
can be shared, and the default visibility settings. If not present, the setting value with can be shared, and the default visibility settings. If not present, the setting value with
key ACCOUNT_VISIBILITY_CONFIGURATION is used. key ACCOUNT_VISIBILITY_CONFIGURATION is used.
...@@ -62,7 +62,7 @@ def get_account_settings(request, username=None, configuration=None, view=None): ...@@ -62,7 +62,7 @@ def get_account_settings(request, username=None, configuration=None, view=None):
"shared", only shared account information will be returned, regardless of `request.user`. "shared", only shared account information will be returned, regardless of `request.user`.
Returns: Returns:
A dict containing account fields. A list of users account details.
Raises: Raises:
UserNotFound: no user with username `username` exists (or `request.user.username` if UserNotFound: no user with username `username` exists (or `request.user.username` if
...@@ -70,27 +70,27 @@ def get_account_settings(request, username=None, configuration=None, view=None): ...@@ -70,27 +70,27 @@ def get_account_settings(request, username=None, configuration=None, view=None):
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
requesting_user = request.user requesting_user = request.user
usernames = usernames or [requesting_user.username]
if username is None: requested_users = User.objects.select_related('profile').filter(username__in=usernames)
username = requesting_user.username if not requested_users:
try:
existing_user = User.objects.select_related('profile').get(username=username)
except ObjectDoesNotExist:
raise UserNotFound() raise UserNotFound()
has_full_access = requesting_user.username == username or requesting_user.is_staff serialized_users = []
if has_full_access and view != 'shared': for user in requested_users:
admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields') has_full_access = requesting_user.is_staff or requesting_user.username == user.username
else: if has_full_access and view != 'shared':
admin_fields = None admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields')
else:
return UserReadOnlySerializer( admin_fields = None
existing_user, serialized_users.append(UserReadOnlySerializer(
configuration=configuration, user,
custom_fields=admin_fields, configuration=configuration,
context={'request': request} custom_fields=admin_fields,
).data context={'request': request}
).data)
return serialized_users
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError]) @intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
......
...@@ -57,13 +57,13 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -57,13 +57,13 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_get_username_provided(self): def test_get_username_provided(self):
"""Test the difference in behavior when a username is supplied to get_account_settings.""" """Test the difference in behavior when a username is supplied to get_account_settings."""
account_settings = get_account_settings(self.default_request) account_settings = get_account_settings(self.default_request)[0]
self.assertEqual(self.user.username, account_settings["username"]) self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.user.username) account_settings = get_account_settings(self.default_request, usernames=[self.user.username])[0]
self.assertEqual(self.user.username, account_settings["username"]) self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.different_user.username) account_settings = get_account_settings(self.default_request, usernames=[self.different_user.username])[0]
self.assertEqual(self.different_user.username, account_settings["username"]) self.assertEqual(self.different_user.username, account_settings["username"])
def test_get_configuration_provided(self): def test_get_configuration_provided(self):
...@@ -81,20 +81,20 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -81,20 +81,20 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
} }
# With default configuration settings, email is not shared with other (non-staff) users. # With default configuration settings, email is not shared with other (non-staff) users.
account_settings = get_account_settings(self.default_request, self.different_user.username) account_settings = get_account_settings(self.default_request, [self.different_user.username])[0]
self.assertNotIn("email", account_settings) self.assertNotIn("email", account_settings)
account_settings = get_account_settings( account_settings = get_account_settings(
self.default_request, self.default_request,
self.different_user.username, [self.different_user.username],
configuration=config, configuration=config,
) )[0]
self.assertEqual(self.different_user.email, account_settings["email"]) self.assertEqual(self.different_user.email, account_settings["email"])
def test_get_user_not_found(self): def test_get_user_not_found(self):
"""Test that UserNotFound is thrown if there is no user with username.""" """Test that UserNotFound is thrown if there is no user with username."""
with self.assertRaises(UserNotFound): with self.assertRaises(UserNotFound):
get_account_settings(self.default_request, username="does_not_exist") get_account_settings(self.default_request, usernames=["does_not_exist"])
self.user.username = "does_not_exist" self.user.username = "does_not_exist"
request = self.request_factory.get("/api/user/v1/accounts/") request = self.request_factory.get("/api/user/v1/accounts/")
...@@ -105,11 +105,11 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -105,11 +105,11 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_update_username_provided(self): def test_update_username_provided(self):
"""Test the difference in behavior when a username is supplied to update_account_settings.""" """Test the difference in behavior when a username is supplied to update_account_settings."""
update_account_settings(self.user, {"name": "Mickey Mouse"}) update_account_settings(self.user, {"name": "Mickey Mouse"})
account_settings = get_account_settings(self.default_request) account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"]) self.assertEqual("Mickey Mouse", account_settings["name"])
update_account_settings(self.user, {"name": "Donald Duck"}, username=self.user.username) update_account_settings(self.user, {"name": "Donald Duck"}, username=self.user.username)
account_settings = get_account_settings(self.default_request) account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Donald Duck", account_settings["name"]) self.assertEqual("Donald Duck", account_settings["name"])
with self.assertRaises(UserNotAuthorized): with self.assertRaises(UserNotAuthorized):
...@@ -189,7 +189,7 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -189,7 +189,7 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
self.assertIn("Error thrown from do_email_change_request", context_manager.exception.developer_message) self.assertIn("Error thrown from do_email_change_request", context_manager.exception.developer_message)
# Verify that the name change happened, even though the attempt to send the email failed. # Verify that the name change happened, even though the attempt to send the email failed.
account_settings = get_account_settings(self.default_request) account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"]) self.assertEqual("Mickey Mouse", account_settings["name"])
@patch('openedx.core.djangoapps.user_api.accounts.serializers.AccountUserSerializer.save') @patch('openedx.core.djangoapps.user_api.accounts.serializers.AccountUserSerializer.save')
...@@ -255,7 +255,7 @@ class AccountSettingsOnCreationTest(TestCase): ...@@ -255,7 +255,7 @@ class AccountSettingsOnCreationTest(TestCase):
user = User.objects.get(username=self.USERNAME) user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get("/api/user/v1/accounts/") request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user request.user = user
account_settings = get_account_settings(request) account_settings = get_account_settings(request)[0]
# Expect a date joined field but remove it to simplify the following comparison # Expect a date joined field but remove it to simplify the following comparison
self.assertIsNotNone(account_settings['date_joined']) self.assertIsNotNone(account_settings['date_joined'])
...@@ -341,14 +341,14 @@ class AccountCreationActivationAndPasswordChangeTest(TestCase): ...@@ -341,14 +341,14 @@ class AccountCreationActivationAndPasswordChangeTest(TestCase):
request = RequestFactory().get("/api/user/v1/accounts/") request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user request.user = user
account = get_account_settings(request) account = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account["username"]) self.assertEqual(self.USERNAME, account["username"])
self.assertEqual(self.EMAIL, account["email"]) self.assertEqual(self.EMAIL, account["email"])
self.assertFalse(account["is_active"]) self.assertFalse(account["is_active"])
# Activate the account and verify that it is now active # Activate the account and verify that it is now active
activate_account(activation_key) activate_account(activation_key)
account = get_account_settings(request) account = get_account_settings(request)[0]
self.assertTrue(account['is_active']) self.assertTrue(account['is_active'])
def test_create_account_duplicate_username(self): def test_create_account_duplicate_username(self):
......
...@@ -10,6 +10,7 @@ from rest_framework import permissions ...@@ -10,6 +10,7 @@ from rest_framework import permissions
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from openedx.core.lib.api.authentication import ( from openedx.core.lib.api.authentication import (
SessionAuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser,
...@@ -20,7 +21,7 @@ from .api import get_account_settings, update_account_settings ...@@ -20,7 +21,7 @@ from .api import get_account_settings, update_account_settings
from ..errors import UserNotFound, UserNotAuthorized, AccountUpdateError, AccountValidationError from ..errors import UserNotFound, UserNotAuthorized, AccountUpdateError, AccountValidationError
class AccountView(APIView): class AccountViewSet(ViewSet):
""" """
**Use Cases** **Use Cases**
...@@ -29,6 +30,7 @@ class AccountView(APIView): ...@@ -29,6 +30,7 @@ class AccountView(APIView):
**Example Requests** **Example Requests**
GET /api/user/v1/accounts?usernames={username1,username2}[?view=shared]
GET /api/user/v1/accounts/{username}/[?view=shared] GET /api/user/v1/accounts/{username}/[?view=shared]
PATCH /api/user/v1/accounts/{username}/{"key":"value"} "application/merge-patch+json" PATCH /api/user/v1/accounts/{username}/{"key":"value"} "application/merge-patch+json"
...@@ -146,19 +148,34 @@ class AccountView(APIView): ...@@ -146,19 +148,34 @@ class AccountView(APIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
parser_classes = (MergePatchParser,) parser_classes = (MergePatchParser,)
def get(self, request, username): def list(self, request):
""" """
GET /api/user/v1/accounts/{username}/ GET /api/user/v1/accounts?username={username1,username2}
""" """
usernames = request.GET.get('username')
try: try:
if usernames:
usernames = usernames.strip(',').split(',')
account_settings = get_account_settings( account_settings = get_account_settings(
request, username, view=request.query_params.get('view')) request, usernames, view=request.query_params.get('view'))
except UserNotFound: except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND) return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings) return Response(account_settings)
def patch(self, request, username): def retrieve(self, request, username):
"""
GET /api/user/v1/accounts/{username}/
"""
try:
account_settings = get_account_settings(
request, [username], view=request.query_params.get('view'))
except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings[0])
def partial_update(self, request, username):
""" """
PATCH /api/user/v1/accounts/{username}/ PATCH /api/user/v1/accounts/{username}/
...@@ -169,7 +186,7 @@ class AccountView(APIView): ...@@ -169,7 +186,7 @@ class AccountView(APIView):
try: try:
with transaction.atomic(): with transaction.atomic():
update_account_settings(request.user, request.data, username=username) update_account_settings(request.user, request.data, username=username)
account_settings = get_account_settings(request, username) account_settings = get_account_settings(request, [username])[0]
except UserNotAuthorized: except UserNotAuthorized:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND) return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
except UserNotFound: except UserNotFound:
......
...@@ -1366,7 +1366,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1366,7 +1366,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME) user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url') request = RequestFactory().get('/url')
request.user = user request.user = user
account_settings = get_account_settings(request) account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"]) self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"]) self.assertEqual(self.EMAIL, account_settings["email"])
...@@ -1406,7 +1406,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1406,7 +1406,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME) user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url') request = RequestFactory().get('/url')
request.user = user request.user = user
account_settings = get_account_settings(request) account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings["level_of_education"], self.EDUCATION) self.assertEqual(account_settings["level_of_education"], self.EDUCATION)
self.assertEqual(account_settings["mailing_address"], self.ADDRESS) self.assertEqual(account_settings["mailing_address"], self.ADDRESS)
...@@ -1440,7 +1440,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1440,7 +1440,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME) user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url') request = RequestFactory().get('/url')
request.user = user request.user = user
account_settings = get_account_settings(request) account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"]) self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"]) self.assertEqual(self.EMAIL, account_settings["email"])
......
...@@ -6,16 +6,23 @@ from django.conf import settings ...@@ -6,16 +6,23 @@ from django.conf import settings
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from ..profile_images.views import ProfileImageView from ..profile_images.views import ProfileImageView
from .accounts.views import AccountView from .accounts.views import AccountViewSet
from .preferences.views import PreferencesView, PreferencesDetailView from .preferences.views import PreferencesView, PreferencesDetailView
ACCOUNT_LIST = AccountViewSet.as_view({
'get': 'list',
})
ACCOUNT_DETAIL = AccountViewSet.as_view({
'get': 'retrieve',
'patch': 'partial_update',
})
urlpatterns = patterns( urlpatterns = patterns(
'', '',
url( url(r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN), ACCOUNT_DETAIL, name='accounts_api'),
r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN), url(r'^v1/accounts$', ACCOUNT_LIST, name='accounts_detail_api'),
AccountView.as_view(),
name="accounts_api"
),
url( url(
r'^v1/accounts/{}/image$'.format(settings.USERNAME_PATTERN), r'^v1/accounts/{}/image$'.format(settings.USERNAME_PATTERN),
ProfileImageView.as_view(), ProfileImageView.as_view(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment