Commit f19af520 by wajeeha-khalid

MA-2050: add optional user profile details in GET thread & comment

parent c681c567
......@@ -9,6 +9,8 @@ from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse
from django.http import Http404
import itertools
from enum import Enum
from openedx.core.djangoapps.user_api.accounts.views import AccountViewSet
from rest_framework.exceptions import PermissionDenied
......@@ -61,6 +63,14 @@ class DiscussionTopic(object):
self.children = children or [] # children are of same type i.e. DiscussionTopic
class DiscussionEntity(Enum):
"""
Enum for different types of discussion related entities
"""
thread = 'thread'
comment = 'comment'
def _get_course(course_key, user):
"""
Get the course descriptor, raising CourseNotFoundError if the course is not found or
......@@ -318,6 +328,138 @@ def get_course_topics(request, course_key, topic_ids=None):
}
def _get_user_profile_dict(request, usernames):
"""
Gets user profile details for a list of usernames and creates a dictionary with
profile details against username.
Parameters:
request: The django request object.
usernames: A string of comma separated usernames.
Returns:
A dict with username as key and user profile details as value.
"""
request.GET = request.GET.copy() # Make a mutable copy of the GET parameters.
request.GET['username'] = usernames
user_profile_details = AccountViewSet.as_view({'get': 'list'})(request).data
return {user['username']: user for user in user_profile_details}
def _user_profile(user_profile):
"""
Returns the user profile object. For now, this just comprises the
profile_image details.
"""
return {
'profile': {
'image': user_profile['profile_image']
}
}
def _get_users(discussion_entity_type, discussion_entity, username_profile_dict):
"""
Returns users with profile details for given discussion thread/comment.
Parameters:
discussion_entity_type: DiscussionEntity Enum value for Thread or Comment.
discussion_entity: Serialized thread/comment.
username_profile_dict: A dict with user profile details against username.
Returns:
A dict of users with username as key and user profile details as value.
"""
users = {discussion_entity['author']: _user_profile(username_profile_dict[discussion_entity['author']])}
if discussion_entity_type == DiscussionEntity.comment and discussion_entity['endorsed']:
users[discussion_entity['endorsed_by']] = _user_profile(username_profile_dict[discussion_entity['endorsed_by']])
return users
def _add_additional_response_fields(
request, serialized_discussion_entities, usernames, discussion_entity_type, include_profile_image
):
"""
Adds additional data to serialized discussion thread/comment.
Parameters:
request: The django request object.
serialized_discussion_entities: A list of serialized Thread/Comment.
usernames: A list of usernames involved in threads/comments (e.g. as author or as comment endorser).
discussion_entity_type: DiscussionEntity Enum value for Thread or Comment.
include_profile_image: (boolean) True if requested_fields has 'profile_image' else False.
Returns:
A list of serialized discussion thread/comment with additional data if requested.
"""
if include_profile_image:
username_profile_dict = _get_user_profile_dict(request, usernames=','.join(usernames))
for discussion_entity in serialized_discussion_entities:
discussion_entity['users'] = _get_users(discussion_entity_type, discussion_entity, username_profile_dict)
return serialized_discussion_entities
def _include_profile_image(requested_fields):
"""
Returns True if requested_fields list has 'profile_image' entity else False
"""
return requested_fields and 'profile_image' in requested_fields
def _serialize_discussion_entities(request, context, discussion_entities, requested_fields, discussion_entity_type):
"""
It serializes Discussion Entity (Thread or Comment) and add additional data if requested.
For a given list of Thread/Comment; it serializes and add additional information to the
object as per requested_fields list (i.e. profile_image).
Parameters:
request: The django request object
context: The context appropriate for use with the thread or comment
discussion_entities: List of Thread or Comment objects
requested_fields: Indicates which additional fields to return
for each thread.
discussion_entity_type: DiscussionEntity Enum value for Thread or Comment
Returns:
A list of serialized discussion entities
"""
results = []
usernames = []
include_profile_image = _include_profile_image(requested_fields)
for entity in discussion_entities:
if discussion_entity_type == DiscussionEntity.thread:
serialized_entity = ThreadSerializer(entity, context=context).data
elif discussion_entity_type == DiscussionEntity.comment:
serialized_entity = CommentSerializer(entity, context=context).data
results.append(serialized_entity)
if include_profile_image:
if serialized_entity['author'] not in usernames:
usernames.append(serialized_entity['author'])
if (
'endorsed' in serialized_entity and serialized_entity['endorsed'] and
'endorsed_by' in serialized_entity and serialized_entity['endorsed_by'] not in usernames
):
usernames.append(serialized_entity['endorsed_by'])
results = _add_additional_response_fields(
request, results, usernames, discussion_entity_type, include_profile_image
)
return results
def get_thread_list(
request,
course_key,
......@@ -329,6 +471,7 @@ def get_thread_list(
view=None,
order_by="last_activity_at",
order_direction="desc",
requested_fields=None,
):
"""
Return the list of all discussion threads pertaining to the given course
......@@ -348,6 +491,8 @@ def get_thread_list(
"last_activity_at".
order_direction: The direction in which to sort the threads by. The only
values are "asc" or "desc". The default is "desc".
requested_fields: Indicates which additional fields to return
for each thread. (i.e. ['profile_image'])
Note that topic_id_list, text_search, and following are mutually exclusive.
......@@ -418,7 +563,9 @@ def get_thread_list(
if paginated_results.page != page:
raise PageNotFoundError("Page not found (No results on this page).")
results = [ThreadSerializer(thread, context=context).data for thread in paginated_results.collection]
results = _serialize_discussion_entities(
request, context, paginated_results.collection, requested_fields, DiscussionEntity.thread
)
paginator = DiscussionAPIPagination(
request,
......@@ -432,7 +579,7 @@ def get_thread_list(
})
def get_comment_list(request, thread_id, endorsed, page, page_size):
def get_comment_list(request, thread_id, endorsed, page, page_size, requested_fields=None):
"""
Return the list of comments in the given thread.
......@@ -451,6 +598,9 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
page_size: The number of comments to retrieve per page
requested_fields: Indicates which additional fields to return for
each comment. (i.e. ['profile_image'])
Returns:
A paginated result containing a list of comments; see
......@@ -497,7 +647,8 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
raise PageNotFoundError("Page not found (No results on this page).")
num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1
results = [CommentSerializer(response, context=context).data for response in responses]
results = _serialize_discussion_entities(request, context, responses, requested_fields, DiscussionEntity.comment)
paginator = DiscussionAPIPagination(request, page, num_pages, resp_total)
return paginator.get_paginated_response(results)
......@@ -791,7 +942,7 @@ def update_comment(request, comment_id, update_data):
return api_comment
def get_thread(request, thread_id):
def get_thread(request, thread_id, requested_fields=None):
"""
Retrieve a thread.
......@@ -802,17 +953,18 @@ def get_thread(request, thread_id):
thread_id: The id for the thread to retrieve
requested_fields: Indicates which additional fields to return for
thread. (i.e. ['profile_image'])
"""
cc_thread, context = _get_thread_and_context(
request,
thread_id,
retrieve_kwargs={"user_id": unicode(request.user.id)}
)
serializer = ThreadSerializer(cc_thread, context=context)
return serializer.data
return _serialize_discussion_entities(request, context, [cc_thread], requested_fields, DiscussionEntity.thread)[0]
def get_response_comments(request, comment_id, page, page_size):
def get_response_comments(request, comment_id, page, page_size, requested_fields=None):
"""
Return the list of comments for the given thread response.
......@@ -827,6 +979,9 @@ def get_response_comments(request, comment_id, page, page_size):
page_size: The number of comments to retrieve per page
requested_fields: Indicates which additional fields to return for
each child comment. (i.e. ['profile_image'])
Returns:
A paginated result containing a list of comments
......@@ -856,7 +1011,9 @@ def get_response_comments(request, comment_id, page, page_size):
if len(paged_response_comments) == 0 and page != 1:
raise PageNotFoundError("Page not found (No results on this page).")
results = [CommentSerializer(comment, context=context).data for comment in paged_response_comments]
results = _serialize_discussion_entities(
request, context, paged_response_comments, requested_fields, DiscussionEntity.comment
)
comments_count = len(response_comments)
num_pages = (comments_count + page_size - 1) / page_size if comments_count else 1
......
......@@ -51,6 +51,7 @@ class ThreadListGetForm(_PaginationForm):
choices=[(choice, choice) for choice in ["asc", "desc"]],
required=False
)
requested_fields = MultiValueField(required=False)
def clean_order_by(self):
"""Return a default choice"""
......@@ -106,6 +107,7 @@ class CommentListGetForm(_PaginationForm):
"""
thread_id = CharField()
endorsed = ExtendedNullBooleanField(required=False)
requested_fields = MultiValueField(required=False)
class CommentActionsForm(Form):
......@@ -115,3 +117,10 @@ class CommentActionsForm(Form):
"""
voted = BooleanField(required=False)
abuse_flagged = BooleanField(required=False)
class CommentGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment retrieval endpoint
"""
requested_fields = MultiValueField(required=False)
......@@ -40,7 +40,8 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
make_paginated_api_response
make_paginated_api_response,
ProfileImageTestMixin,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
......@@ -1148,7 +1149,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_basic_query_params(self):
self.get_comment_list(
self.make_minimal_cs_thread({
"children": [make_minimal_cs_comment()],
"children": [make_minimal_cs_comment({"username": self.user.username})],
"resp_total": 71
}),
page=6,
......@@ -1254,8 +1255,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_content(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})],
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment", "username": self.user.username})],
"non_endorsed_responses": [make_minimal_cs_comment({
"id": "non_endorsed_comment", "username": self.user.username
})],
"non_endorsed_resp_total": 1,
})
......@@ -1274,7 +1277,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
"anonymous": True,
"children": [
make_minimal_cs_comment({
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"}
"username": self.user.username,
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"},
})
]
})
......@@ -1301,7 +1305,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
# number of responses is unrealistic but convenient for this test
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [make_minimal_cs_comment()],
response_field: [make_minimal_cs_comment({"username": self.user.username})],
response_total_field: 5,
})
......@@ -1337,9 +1341,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_endorsed_pagination(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [
make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10)
]
"endorsed_responses": [make_minimal_cs_comment({
"id": "comment_{}".format(i),
"username": self.user.username
}) for i in range(10)]
})
def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev):
......@@ -1553,7 +1558,7 @@ class CreateThreadTest(
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_post_thread_response({})
self.register_post_thread_response({"username": self.user.username})
data = self.minimal_data.copy()
data["course_id"] = unicode(cohort_course.id)
if data_group_state == "group_is_none":
......@@ -1582,7 +1587,7 @@ class CreateThreadTest(
self.fail("Unexpected validation error: {}".format(ex))
def test_following(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_subscription_response(self.user)
data = self.minimal_data.copy()
data["following"] = "True"
......@@ -1600,7 +1605,7 @@ class CreateThreadTest(
)
def test_voted(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_votes_response("test_id")
data = self.minimal_data.copy()
data["voted"] = "True"
......@@ -1616,7 +1621,7 @@ class CreateThreadTest(
)
def test_abuse_flagged(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_flag_response("test_id")
data = self.minimal_data.copy()
data["abuse_flagged"] = "True"
......@@ -1802,7 +1807,7 @@ class CreateCommentTest(
"user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1),
})
)
self.register_post_comment_response({}, "test_thread")
self.register_post_comment_response({"username": self.user.username}, "test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
expected_error = (
......@@ -1817,7 +1822,7 @@ class CreateCommentTest(
self.assertTrue(expected_error)
def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_votes_response("test_comment")
data = self.minimal_data.copy()
data["voted"] = "True"
......@@ -1833,7 +1838,7 @@ class CreateCommentTest(
)
def test_abuse_flagged(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_flag_response("test_comment")
data = self.minimal_data.copy()
data["abuse_flagged"] = "True"
......@@ -1906,7 +1911,7 @@ class CreateCommentTest(
cohort.id + 1
),
}))
self.register_post_comment_response({}, thread_id="cohort_thread")
self.register_post_comment_response({"username": self.user.username}, thread_id="cohort_thread")
data = self.minimal_data.copy()
data["thread_id"] = "cohort_thread"
expected_error = (
......
......@@ -69,6 +69,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"view": "",
"order_by": "last_activity_at",
"order_direction": "desc",
"requested_fields": set(),
}
)
......@@ -154,6 +155,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data[field] = value
self.assert_field_value(field, value)
def test_requested_fields(self):
self.form_data["requested_fields"] = "profile_image"
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
@ddt.ddt
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
......@@ -177,7 +186,8 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"thread_id": "deadbeef",
"endorsed": False,
"page": 2,
"page_size": 13
"page_size": 13,
"requested_fields": set(),
}
)
......@@ -202,3 +212,11 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
def test_invalid_endorsed(self):
self.form_data["endorsed"] = "invalid-boolean"
self.assert_error("endorsed", "Invalid Boolean Value.")
def test_requested_fields(self):
self.form_data["requested_fields"] = {"profile_image"}
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
......@@ -16,6 +16,7 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_thread,
make_minimal_cs_comment,
ProfileImageTestMixin,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
......@@ -237,7 +238,7 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
thread_data = self.make_cs_content({})
del thread_data["pinned"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertEqual(serialized["pinned"], False)
def test_group(self):
......@@ -255,14 +256,14 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
def test_response_count(self):
thread_data = self.make_cs_content({"resp_total": 2})
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertEqual(serialized["response_count"], 2)
def test_response_count_missing(self):
thread_data = self.make_cs_content({})
del thread_data["resp_total"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertNotIn("response_count", serialized)
......@@ -459,6 +460,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
"title": "Original Title",
"body": "Original body",
"user_id": str(self.user.id),
"username": self.user.username,
"read": "False",
"endorsed": "False"
}))
......@@ -480,7 +482,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
return serializer.data
def test_create_minimal(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
saved = self.save_and_reserialize(self.minimal_data)
self.assertEqual(
urlparse(httpretty.last_request().path).path,
......@@ -500,7 +502,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
self.assertEqual(saved["id"], "test_id")
def test_create_all_fields(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy()
data["group_id"] = 42
self.save_and_reserialize(data)
......@@ -540,7 +542,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
)
def test_create_type(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy()
data["type"] = "question"
self.save_and_reserialize(data)
......@@ -655,6 +657,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
"thread_id": "existing_thread",
"body": "Original body",
"user_id": str(self.user.id),
"username": self.user.username,
"course_id": unicode(self.course.id),
}))
......@@ -685,7 +688,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["parent_id"] = parent_id
self.register_get_comment_response({"thread_id": "test_thread", "id": parent_id})
self.register_post_comment_response(
{"id": "test_comment"},
{"id": "test_comment", "username": self.user.username},
thread_id="test_thread",
parent_id=parent_id
)
......@@ -712,7 +715,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["endorsed"] = True
self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"})
self.register_post_comment_response(
{"id": "test_comment"},
{"id": "test_comment", "username": self.user.username},
thread_id="test_thread",
parent_id="test_parent"
)
......@@ -807,7 +810,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
def test_create_endorsed(self):
# TODO: The comments service doesn't populate the endorsement field on
# comment creation, so this is sadly realistic
self.register_post_comment_response({}, thread_id="test_thread")
self.register_post_comment_response({"username": self.user.username}, thread_id="test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
saved = self.save_and_reserialize(data)
......
......@@ -9,6 +9,7 @@ import ddt
import httpretty
import mock
from nose.plugins.attrib import attr
from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_storage
from pytz import UTC
from django.core.urlresolvers import reverse
......@@ -24,8 +25,8 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
make_paginated_api_response
)
make_paginated_api_response,
ProfileImageTestMixin)
from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin, PatchMediaTypeMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
......@@ -54,6 +55,9 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
)
self.password = "password"
self.user = UserFactory.create(password=self.password)
# Ensure that parental controls don't apply to this user
self.user.profile.year_of_birth = 1970
self.user.profile.save()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.client.login(username=self.user.username, password=self.password)
......@@ -306,7 +310,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
@ddt.ddt
@httpretty.activate
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase, ProfileImageTestMixin):
"""Tests for ThreadViewSet list"""
def setUp(self):
super(ThreadViewSetListTest, self).setUp()
......@@ -627,6 +631,82 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
}
)
def test_profile_image_requested_field(self):
"""
Tests thread has user profile image details if called in requested_fields
"""
user_2 = UserFactory.create(password=self.password)
# Ensure that parental controls don't apply to this user
user_2.profile.year_of_birth = 1970
user_2.profile.save()
source_threads = [
{
"type": "thread",
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "test_topic",
"group_id": None,
"user_id": str(self.user.id),
"username": self.user.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"thread_type": "discussion",
"title": "Test Title",
"body": "Test body",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5,
"unread_comments_count": 3,
"read": False,
"endorsed": False
},
{
"type": "thread",
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "test_topic",
"group_id": None,
"user_id": str(user_2.id),
"username": user_2.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"thread_type": "discussion",
"title": "Test Title",
"body": "Test body",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5,
"unread_comments_count": 3,
"read": False,
"endorsed": False
}
]
self.register_get_user_response(self.user, upvoted_ids=["test_thread"])
self.register_get_threads_response(source_threads, page=1, num_pages=1)
self.create_profile_image(self.user, get_profile_image_storage())
self.create_profile_image(user_2, get_profile_image_storage())
response = self.client.get(
self.url,
{"course_id": unicode(self.course.id), "requested_fields": "profile_image"},
)
self.assertEqual(response.status_code, 200)
response_threads = json.loads(response.content)['results']
for response_thread in response_threads:
expected_profile_data = self.get_expected_user_profile(response_thread['author'])
response_users = response_thread['users']
self.assertEqual(expected_profile_data, response_users[response_thread['author']])
@httpretty.activate
@disable_signal(api, 'thread_created')
......@@ -977,13 +1057,14 @@ class ThreadViewSetDeleteTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
@ddt.ddt
@httpretty.activate
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase, ProfileImageTestMixin):
"""Tests for CommentViewSet list"""
def setUp(self):
super(CommentViewSetListTest, self).setUp()
self.author = UserFactory.create()
self.url = reverse("comment-list")
self.thread_id = "test_thread"
self.storage = get_profile_image_storage()
def make_minimal_cs_thread(self, overrides=None):
"""
......@@ -1142,8 +1223,16 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
self.register_get_user_response(self.user)
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})],
"endorsed_responses": [make_minimal_cs_comment({
"id": "endorsed_comment",
"user_id": self.user.id,
"username": self.user.username,
})],
"non_endorsed_responses": [make_minimal_cs_comment({
"id": "non_endorsed_comment",
"user_id": self.user.id,
"username": self.user.username,
})],
"non_endorsed_resp_total": 1,
})
self.register_get_thread_response(thread)
......@@ -1233,6 +1322,91 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
}
)
def test_profile_image_requested_field(self):
"""
Tests all comments retrieved have user profile image details if called in requested_fields
"""
source_comments = [{
"type": "comment",
"id": "test_comment",
"thread_id": self.thread_id,
"parent_id": None,
"user_id": str(self.user.id),
"username": self.user.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"body": "Test body",
"endorsed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"child_count": 0,
"children": [],
}]
self.register_get_thread_response({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"thread_type": "discussion",
"children": source_comments,
"resp_total": 100,
})
self.register_get_user_response(self.user, upvoted_ids=["test_comment"])
self.create_profile_image(self.user, get_profile_image_storage())
response = self.client.get(self.url, {"thread_id": self.thread_id, "requested_fields": "profile_image"})
self.assertEqual(response.status_code, 200)
response_comments = json.loads(response.content)['results']
for response_comment in response_comments:
expected_profile_data = self.get_expected_user_profile(response_comment['author'])
response_users = response_comment['users']
self.assertEqual(expected_profile_data, response_users[response_comment['author']])
def test_profile_image_requested_field_endorsed_comments(self):
"""
Tests all comments have user profile image details for both author and endorser
if called in requested_fields for endorsed threads
"""
endorser_user = UserFactory.create(password=self.password)
# Ensure that parental controls don't apply to this user
endorser_user.profile.year_of_birth = 1970
endorser_user.profile.save()
self.register_get_user_response(self.user)
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({
"id": "endorsed_comment",
"user_id": self.user.id,
"username": self.user.username,
"endorsed": True,
"endorsement": {"user_id": endorser_user.id, "time": "2016-05-10T08:51:28Z"},
})],
"non_endorsed_responses": [make_minimal_cs_comment({
"id": "non_endorsed_comment",
"user_id": self.user.id,
"username": self.user.username,
})],
"non_endorsed_resp_total": 1,
})
self.register_get_thread_response(thread)
self.create_profile_image(self.user, get_profile_image_storage())
self.create_profile_image(endorser_user, get_profile_image_storage())
response = self.client.get(self.url, {
"thread_id": thread["id"],
"endorsed": True,
"requested_fields": "profile_image",
})
self.assertEqual(response.status_code, 200)
response_comments = json.loads(response.content)['results']
for response_comment in response_comments:
expected_author_profile_data = self.get_expected_user_profile(response_comment['author'])
expected_endorser_profile_data = self.get_expected_user_profile(response_comment['endorsed_by'])
response_users = response_comment['users']
self.assertEqual(expected_author_profile_data, response_users[response_comment['author']])
self.assertEqual(expected_endorser_profile_data, response_users[response_comment['endorsed_by']])
@httpretty.activate
@disable_signal(api, 'comment_deleted')
......@@ -1484,7 +1658,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes
@httpretty.activate
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
class ThreadViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
class ThreadViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase, ProfileImageTestMixin):
"""Tests for ThreadViewSet Retrieve"""
def setUp(self):
super(ThreadViewSetRetrieveTest, self).setUp()
......@@ -1545,10 +1719,29 @@ class ThreadViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase)
response = self.client.get(self.url)
self.assertEqual(response.status_code, 404)
def test_profile_image_requested_field(self):
"""
Tests thread has user profile image details if called in requested_fields
"""
self.register_get_user_response(self.user)
cs_thread = make_minimal_cs_thread({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"username": self.user.username,
"user_id": str(self.user.id),
})
self.register_get_thread_response(cs_thread)
self.create_profile_image(self.user, get_profile_image_storage())
response = self.client.get(self.url, {"requested_fields": "profile_image"})
self.assertEqual(response.status_code, 200)
expected_profile_data = self.get_expected_user_profile(self.user.username)
response_users = json.loads(response.content)['users']
self.assertEqual(expected_profile_data, response_users[self.user.username])
@httpretty.activate
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
class CommentViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
class CommentViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase, ProfileImageTestMixin):
"""Tests for CommentViewSet Retrieve"""
def setUp(self):
super(CommentViewSetRetrieveTest, self).setUp()
......@@ -1642,3 +1835,28 @@ class CommentViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase
404,
{"developer_message": "Page not found (No results on this page)."}
)
def test_profile_image_requested_field(self):
"""
Tests all comments retrieved have user profile image details if called in requested_fields
"""
self.register_get_user_response(self.user)
cs_comment_child = self.make_comment_data('test_child_comment', self.comment_id, children=[])
cs_comment = self.make_comment_data(self.comment_id, None, [cs_comment_child])
cs_thread = make_minimal_cs_thread({
'id': self.thread_id,
'course_id': unicode(self.course.id),
'children': [cs_comment],
})
self.register_get_thread_response(cs_thread)
self.register_get_comment_response(cs_comment)
self.create_profile_image(self.user, get_profile_image_storage())
response = self.client.get(self.url, {'requested_fields': 'profile_image'})
self.assertEqual(response.status_code, 200)
response_comments = json.loads(response.content)['results']
for response_comment in response_comments:
expected_profile_data = self.get_expected_user_profile(response_comment['author'])
response_users = response_comment['users']
self.assertEqual(expected_profile_data, response_users[response_comment['author']])
"""
Discussion API test utilities
"""
from contextlib import closing
from datetime import datetime
import json
import re
import hashlib
import httpretty
from pytz import UTC
from PIL import Image
from openedx.core.djangoapps.profile_images.images import create_profile_images
from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file
from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_names, set_has_profile_image
def _get_thread_callback(thread_data):
......@@ -392,3 +401,56 @@ def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=No
},
"results": results or []
}
class ProfileImageTestMixin(object):
"""
Mixin with utility methods for user profile image
"""
TEST_PROFILE_IMAGE_UPLOADED_AT = datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
def create_profile_image(self, user, storage):
"""
Creates profile image for user and checks that created image exists in storage
"""
with make_image_file() as image_file:
create_profile_images(image_file, get_profile_image_names(user.username))
self.check_images(user, storage)
set_has_profile_image(user.username, True, self.TEST_PROFILE_IMAGE_UPLOADED_AT)
def check_images(self, user, storage, exist=True):
"""
If exist is True, make sure the images physically exist in storage
with correct sizes and formats.
If exist is False, make sure none of the images exist.
"""
for size, name in get_profile_image_names(user.username).items():
if exist:
self.assertTrue(storage.exists(name))
with closing(Image.open(storage.path(name))) as img:
self.assertEqual(img.size, (size, size))
self.assertEqual(img.format, 'JPEG')
else:
self.assertFalse(storage.exists(name))
def get_expected_user_profile(self, username):
"""
Returns the expected user profile data for a given username
"""
url = 'http://example-storage.com/profile-images/{filename}_{{size}}.jpg?v={timestamp}'.format(
filename=hashlib.md5('secret' + username).hexdigest(),
timestamp=self.TEST_PROFILE_IMAGE_UPLOADED_AT.strftime("%s")
)
return {
'profile': {
'image': {
'has_image': True,
'image_url_full': url.format(size=500),
'image_url_large': url.format(size=120),
'image_url_medium': url.format(size=50),
'image_url_small': url.format(size=30),
}
}
}
......@@ -26,7 +26,7 @@ from discussion_api.api import (
update_comment,
update_thread,
)
from discussion_api.forms import CommentListGetForm, ThreadListGetForm, _PaginationForm
from discussion_api.forms import CommentListGetForm, ThreadListGetForm, CommentGetForm
from openedx.core.lib.api.parsers import MergePatchParser
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, view_auth_classes
......@@ -118,7 +118,7 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015
GET /api/discussion/v1/threads/thread_id
GET /api/discussion/v1/threads/{thread_id}
POST /api/discussion/v1/threads
{
......@@ -164,10 +164,19 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
* view: "unread" for threads the requesting user has not read, or
"unanswered" for question threads with no marked answer. Only one
can be selected.
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
The topic_id, text_search, and following parameters are mutually
exclusive (i.e. only one may be specified in a request)
**GET Thread Parameters**:
* thread_id (required): The id of the thread
* requested_fields (optional parameter): (list) Indicates which additional
fields to return for each thread. (supports 'profile_image')
**POST Parameters**:
* course_id (required): The course to create the thread in
......@@ -279,13 +288,15 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["view"],
form.cleaned_data["order_by"],
form.cleaned_data["order_direction"],
form.cleaned_data["requested_fields"]
)
def retrieve(self, request, thread_id=None):
"""
Implements the GET method for thread ID
"""
return Response(get_thread(request, thread_id))
requested_fields = request.GET.get('requested_fields')
return Response(get_thread(request, thread_id, requested_fields))
def create(self, request):
"""
......@@ -351,6 +362,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**GET Child Comment List Parameters**:
* comment_id (required): The comment to retrieve child comments for
......@@ -359,6 +373,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**POST Parameters**:
......@@ -453,21 +470,23 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["thread_id"],
form.cleaned_data["endorsed"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
)
def retrieve(self, request, comment_id=None):
"""
Implements the GET method for comments against response ID
"""
form = _PaginationForm(request.GET)
form = CommentGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
return get_response_comments(
request,
comment_id,
form.cleaned_data["page"],
form.cleaned_data["page_size"]
form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
)
def create(self, request):
......
......@@ -66,7 +66,7 @@ def learner_profile_context(request, profile_username, user_is_staff):
own_profile = (logged_in_user.username == profile_username)
account_settings_data = get_account_settings(request, profile_username)
account_settings_data = get_account_settings(request, [profile_username])[0]
preferences_data = get_user_preferences(profile_user, profile_username)
......
......@@ -1627,7 +1627,7 @@ class TestSubmitPhotosForVerification(TestCase):
"""
request = RequestFactory().get('/url')
request.user = self.user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings['name'], full_name)
def _get_post_data(self):
......
......@@ -41,7 +41,7 @@ visible_fields = _visible_fields
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
def get_account_settings(request, username=None, configuration=None, view=None):
def get_account_settings(request, usernames=None, configuration=None, view=None):
"""Returns account information for a user serialized as JSON.
Note:
......@@ -52,8 +52,8 @@ def get_account_settings(request, username=None, configuration=None, view=None):
request (Request): The request object with account information about the requesting user.
Only the user with username `username` or users with "is_staff" privileges can get full
account information. Other users will get the account fields that the user has elected to share.
username (str): Optional username for the desired account information. If not specified,
`request.user.username` is assumed.
usernames (list): Optional list of usernames for the desired account information. If not
specified, `request.user.username` is assumed.
configuration (dict): an optional configuration specifying which fields in the account
can be shared, and the default visibility settings. If not present, the setting value with
key ACCOUNT_VISIBILITY_CONFIGURATION is used.
......@@ -62,7 +62,7 @@ def get_account_settings(request, username=None, configuration=None, view=None):
"shared", only shared account information will be returned, regardless of `request.user`.
Returns:
A dict containing account fields.
A list of users account details.
Raises:
UserNotFound: no user with username `username` exists (or `request.user.username` if
......@@ -70,27 +70,27 @@ def get_account_settings(request, username=None, configuration=None, view=None):
UserAPIInternalError: the operation failed due to an unexpected error.
"""
requesting_user = request.user
usernames = usernames or [requesting_user.username]
if username is None:
username = requesting_user.username
try:
existing_user = User.objects.select_related('profile').get(username=username)
except ObjectDoesNotExist:
requested_users = User.objects.select_related('profile').filter(username__in=usernames)
if not requested_users:
raise UserNotFound()
has_full_access = requesting_user.username == username or requesting_user.is_staff
if has_full_access and view != 'shared':
admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields')
else:
admin_fields = None
return UserReadOnlySerializer(
existing_user,
configuration=configuration,
custom_fields=admin_fields,
context={'request': request}
).data
serialized_users = []
for user in requested_users:
has_full_access = requesting_user.is_staff or requesting_user.username == user.username
if has_full_access and view != 'shared':
admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields')
else:
admin_fields = None
serialized_users.append(UserReadOnlySerializer(
user,
configuration=configuration,
custom_fields=admin_fields,
context={'request': request}
).data)
return serialized_users
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
......
......@@ -57,13 +57,13 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_get_username_provided(self):
"""Test the difference in behavior when a username is supplied to get_account_settings."""
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.user.username)
account_settings = get_account_settings(self.default_request, usernames=[self.user.username])[0]
self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.different_user.username)
account_settings = get_account_settings(self.default_request, usernames=[self.different_user.username])[0]
self.assertEqual(self.different_user.username, account_settings["username"])
def test_get_configuration_provided(self):
......@@ -81,20 +81,20 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
}
# With default configuration settings, email is not shared with other (non-staff) users.
account_settings = get_account_settings(self.default_request, self.different_user.username)
account_settings = get_account_settings(self.default_request, [self.different_user.username])[0]
self.assertNotIn("email", account_settings)
account_settings = get_account_settings(
self.default_request,
self.different_user.username,
[self.different_user.username],
configuration=config,
)
)[0]
self.assertEqual(self.different_user.email, account_settings["email"])
def test_get_user_not_found(self):
"""Test that UserNotFound is thrown if there is no user with username."""
with self.assertRaises(UserNotFound):
get_account_settings(self.default_request, username="does_not_exist")
get_account_settings(self.default_request, usernames=["does_not_exist"])
self.user.username = "does_not_exist"
request = self.request_factory.get("/api/user/v1/accounts/")
......@@ -105,11 +105,11 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_update_username_provided(self):
"""Test the difference in behavior when a username is supplied to update_account_settings."""
update_account_settings(self.user, {"name": "Mickey Mouse"})
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"])
update_account_settings(self.user, {"name": "Donald Duck"}, username=self.user.username)
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Donald Duck", account_settings["name"])
with self.assertRaises(UserNotAuthorized):
......@@ -189,7 +189,7 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
self.assertIn("Error thrown from do_email_change_request", context_manager.exception.developer_message)
# Verify that the name change happened, even though the attempt to send the email failed.
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"])
@patch('openedx.core.djangoapps.user_api.accounts.serializers.AccountUserSerializer.save')
......@@ -255,7 +255,7 @@ class AccountSettingsOnCreationTest(TestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
# Expect a date joined field but remove it to simplify the following comparison
self.assertIsNotNone(account_settings['date_joined'])
......@@ -341,14 +341,14 @@ class AccountCreationActivationAndPasswordChangeTest(TestCase):
request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user
account = get_account_settings(request)
account = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account["username"])
self.assertEqual(self.EMAIL, account["email"])
self.assertFalse(account["is_active"])
# Activate the account and verify that it is now active
activate_account(activation_key)
account = get_account_settings(request)
account = get_account_settings(request)[0]
self.assertTrue(account['is_active'])
def test_create_account_duplicate_username(self):
......
......@@ -10,6 +10,7 @@ from rest_framework import permissions
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from openedx.core.lib.api.authentication import (
SessionAuthenticationAllowInactiveUser,
......@@ -20,7 +21,7 @@ from .api import get_account_settings, update_account_settings
from ..errors import UserNotFound, UserNotAuthorized, AccountUpdateError, AccountValidationError
class AccountView(APIView):
class AccountViewSet(ViewSet):
"""
**Use Cases**
......@@ -29,6 +30,7 @@ class AccountView(APIView):
**Example Requests**
GET /api/user/v1/accounts?usernames={username1,username2}[?view=shared]
GET /api/user/v1/accounts/{username}/[?view=shared]
PATCH /api/user/v1/accounts/{username}/{"key":"value"} "application/merge-patch+json"
......@@ -146,19 +148,34 @@ class AccountView(APIView):
permission_classes = (permissions.IsAuthenticated,)
parser_classes = (MergePatchParser,)
def get(self, request, username):
def list(self, request):
"""
GET /api/user/v1/accounts/{username}/
GET /api/user/v1/accounts?username={username1,username2}
"""
usernames = request.GET.get('username')
try:
if usernames:
usernames = usernames.strip(',').split(',')
account_settings = get_account_settings(
request, username, view=request.query_params.get('view'))
request, usernames, view=request.query_params.get('view'))
except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings)
def patch(self, request, username):
def retrieve(self, request, username):
"""
GET /api/user/v1/accounts/{username}/
"""
try:
account_settings = get_account_settings(
request, [username], view=request.query_params.get('view'))
except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings[0])
def partial_update(self, request, username):
"""
PATCH /api/user/v1/accounts/{username}/
......@@ -169,7 +186,7 @@ class AccountView(APIView):
try:
with transaction.atomic():
update_account_settings(request.user, request.data, username=username)
account_settings = get_account_settings(request, username)
account_settings = get_account_settings(request, [username])[0]
except UserNotAuthorized:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
except UserNotFound:
......
......@@ -1366,7 +1366,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"])
......@@ -1406,7 +1406,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings["level_of_education"], self.EDUCATION)
self.assertEqual(account_settings["mailing_address"], self.ADDRESS)
......@@ -1440,7 +1440,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"])
......
......@@ -6,16 +6,23 @@ from django.conf import settings
from django.conf.urls import patterns, url
from ..profile_images.views import ProfileImageView
from .accounts.views import AccountView
from .accounts.views import AccountViewSet
from .preferences.views import PreferencesView, PreferencesDetailView
ACCOUNT_LIST = AccountViewSet.as_view({
'get': 'list',
})
ACCOUNT_DETAIL = AccountViewSet.as_view({
'get': 'retrieve',
'patch': 'partial_update',
})
urlpatterns = patterns(
'',
url(
r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN),
AccountView.as_view(),
name="accounts_api"
),
url(r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN), ACCOUNT_DETAIL, name='accounts_api'),
url(r'^v1/accounts$', ACCOUNT_LIST, name='accounts_detail_api'),
url(
r'^v1/accounts/{}/image$'.format(settings.USERNAME_PATTERN),
ProfileImageView.as_view(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment