Commit f19af520 by wajeeha-khalid

MA-2050: add optional user profile details in GET thread & comment

parent c681c567
......@@ -51,6 +51,7 @@ class ThreadListGetForm(_PaginationForm):
choices=[(choice, choice) for choice in ["asc", "desc"]],
required=False
)
requested_fields = MultiValueField(required=False)
def clean_order_by(self):
"""Return a default choice"""
......@@ -106,6 +107,7 @@ class CommentListGetForm(_PaginationForm):
"""
thread_id = CharField()
endorsed = ExtendedNullBooleanField(required=False)
requested_fields = MultiValueField(required=False)
class CommentActionsForm(Form):
......@@ -115,3 +117,10 @@ class CommentActionsForm(Form):
"""
voted = BooleanField(required=False)
abuse_flagged = BooleanField(required=False)
class CommentGetForm(_PaginationForm):
"""
A form to validate query parameters in the comment retrieval endpoint
"""
requested_fields = MultiValueField(required=False)
......@@ -40,7 +40,8 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
make_paginated_api_response
make_paginated_api_response,
ProfileImageTestMixin,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
......@@ -1148,7 +1149,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_basic_query_params(self):
self.get_comment_list(
self.make_minimal_cs_thread({
"children": [make_minimal_cs_comment()],
"children": [make_minimal_cs_comment({"username": self.user.username})],
"resp_total": 71
}),
page=6,
......@@ -1254,8 +1255,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_content(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})],
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment", "username": self.user.username})],
"non_endorsed_responses": [make_minimal_cs_comment({
"id": "non_endorsed_comment", "username": self.user.username
})],
"non_endorsed_resp_total": 1,
})
......@@ -1274,7 +1277,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
"anonymous": True,
"children": [
make_minimal_cs_comment({
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"}
"username": self.user.username,
"endorsement": {"user_id": str(self.author.id), "time": "2015-05-18T12:34:56Z"},
})
]
})
......@@ -1301,7 +1305,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
# number of responses is unrealistic but convenient for this test
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [make_minimal_cs_comment()],
response_field: [make_minimal_cs_comment({"username": self.user.username})],
response_total_field: 5,
})
......@@ -1337,9 +1341,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
def test_question_endorsed_pagination(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [
make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10)
]
"endorsed_responses": [make_minimal_cs_comment({
"id": "comment_{}".format(i),
"username": self.user.username
}) for i in range(10)]
})
def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev):
......@@ -1553,7 +1558,7 @@ class CreateThreadTest(
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_post_thread_response({})
self.register_post_thread_response({"username": self.user.username})
data = self.minimal_data.copy()
data["course_id"] = unicode(cohort_course.id)
if data_group_state == "group_is_none":
......@@ -1582,7 +1587,7 @@ class CreateThreadTest(
self.fail("Unexpected validation error: {}".format(ex))
def test_following(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_subscription_response(self.user)
data = self.minimal_data.copy()
data["following"] = "True"
......@@ -1600,7 +1605,7 @@ class CreateThreadTest(
)
def test_voted(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_votes_response("test_id")
data = self.minimal_data.copy()
data["voted"] = "True"
......@@ -1616,7 +1621,7 @@ class CreateThreadTest(
)
def test_abuse_flagged(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
self.register_thread_flag_response("test_id")
data = self.minimal_data.copy()
data["abuse_flagged"] = "True"
......@@ -1802,7 +1807,7 @@ class CreateCommentTest(
"user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1),
})
)
self.register_post_comment_response({}, "test_thread")
self.register_post_comment_response({"username": self.user.username}, "test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
expected_error = (
......@@ -1817,7 +1822,7 @@ class CreateCommentTest(
self.assertTrue(expected_error)
def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_votes_response("test_comment")
data = self.minimal_data.copy()
data["voted"] = "True"
......@@ -1833,7 +1838,7 @@ class CreateCommentTest(
)
def test_abuse_flagged(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_post_comment_response({"id": "test_comment", "username": self.user.username}, "test_thread")
self.register_comment_flag_response("test_comment")
data = self.minimal_data.copy()
data["abuse_flagged"] = "True"
......@@ -1906,7 +1911,7 @@ class CreateCommentTest(
cohort.id + 1
),
}))
self.register_post_comment_response({}, thread_id="cohort_thread")
self.register_post_comment_response({"username": self.user.username}, thread_id="cohort_thread")
data = self.minimal_data.copy()
data["thread_id"] = "cohort_thread"
expected_error = (
......
......@@ -69,6 +69,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"view": "",
"order_by": "last_activity_at",
"order_direction": "desc",
"requested_fields": set(),
}
)
......@@ -154,6 +155,14 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data[field] = value
self.assert_field_value(field, value)
def test_requested_fields(self):
self.form_data["requested_fields"] = "profile_image"
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
@ddt.ddt
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
......@@ -177,7 +186,8 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"thread_id": "deadbeef",
"endorsed": False,
"page": 2,
"page_size": 13
"page_size": 13,
"requested_fields": set(),
}
)
......@@ -202,3 +212,11 @@ class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
def test_invalid_endorsed(self):
self.form_data["endorsed"] = "invalid-boolean"
self.assert_error("endorsed", "Invalid Boolean Value.")
def test_requested_fields(self):
self.form_data["requested_fields"] = {"profile_image"}
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data["requested_fields"],
{"profile_image"},
)
......@@ -16,6 +16,7 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_thread,
make_minimal_cs_comment,
ProfileImageTestMixin,
)
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
......@@ -237,7 +238,7 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
thread_data = self.make_cs_content({})
del thread_data["pinned"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertEqual(serialized["pinned"], False)
def test_group(self):
......@@ -255,14 +256,14 @@ class ThreadSerializerSerializationTest(SerializerTestMixin, SharedModuleStoreTe
def test_response_count(self):
thread_data = self.make_cs_content({"resp_total": 2})
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertEqual(serialized["response_count"], 2)
def test_response_count_missing(self):
thread_data = self.make_cs_content({})
del thread_data["resp_total"]
self.register_get_thread_response(thread_data)
serialized = self.serialize(Thread(id=thread_data["id"]))
serialized = self.serialize(thread_data)
self.assertNotIn("response_count", serialized)
......@@ -459,6 +460,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
"title": "Original Title",
"body": "Original body",
"user_id": str(self.user.id),
"username": self.user.username,
"read": "False",
"endorsed": "False"
}))
......@@ -480,7 +482,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
return serializer.data
def test_create_minimal(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
saved = self.save_and_reserialize(self.minimal_data)
self.assertEqual(
urlparse(httpretty.last_request().path).path,
......@@ -500,7 +502,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
self.assertEqual(saved["id"], "test_id")
def test_create_all_fields(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy()
data["group_id"] = 42
self.save_and_reserialize(data)
......@@ -540,7 +542,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
)
def test_create_type(self):
self.register_post_thread_response({"id": "test_id"})
self.register_post_thread_response({"id": "test_id", "username": self.user.username})
data = self.minimal_data.copy()
data["type"] = "question"
self.save_and_reserialize(data)
......@@ -655,6 +657,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
"thread_id": "existing_thread",
"body": "Original body",
"user_id": str(self.user.id),
"username": self.user.username,
"course_id": unicode(self.course.id),
}))
......@@ -685,7 +688,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["parent_id"] = parent_id
self.register_get_comment_response({"thread_id": "test_thread", "id": parent_id})
self.register_post_comment_response(
{"id": "test_comment"},
{"id": "test_comment", "username": self.user.username},
thread_id="test_thread",
parent_id=parent_id
)
......@@ -712,7 +715,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["endorsed"] = True
self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"})
self.register_post_comment_response(
{"id": "test_comment"},
{"id": "test_comment", "username": self.user.username},
thread_id="test_thread",
parent_id="test_parent"
)
......@@ -807,7 +810,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
def test_create_endorsed(self):
# TODO: The comments service doesn't populate the endorsement field on
# comment creation, so this is sadly realistic
self.register_post_comment_response({}, thread_id="test_thread")
self.register_post_comment_response({"username": self.user.username}, thread_id="test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
saved = self.save_and_reserialize(data)
......
"""
Discussion API test utilities
"""
from contextlib import closing
from datetime import datetime
import json
import re
import hashlib
import httpretty
from pytz import UTC
from PIL import Image
from openedx.core.djangoapps.profile_images.images import create_profile_images
from openedx.core.djangoapps.profile_images.tests.helpers import make_image_file
from openedx.core.djangoapps.user_api.accounts.image_helpers import get_profile_image_names, set_has_profile_image
def _get_thread_callback(thread_data):
......@@ -392,3 +401,56 @@ def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=No
},
"results": results or []
}
class ProfileImageTestMixin(object):
"""
Mixin with utility methods for user profile image
"""
TEST_PROFILE_IMAGE_UPLOADED_AT = datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
def create_profile_image(self, user, storage):
"""
Creates profile image for user and checks that created image exists in storage
"""
with make_image_file() as image_file:
create_profile_images(image_file, get_profile_image_names(user.username))
self.check_images(user, storage)
set_has_profile_image(user.username, True, self.TEST_PROFILE_IMAGE_UPLOADED_AT)
def check_images(self, user, storage, exist=True):
"""
If exist is True, make sure the images physically exist in storage
with correct sizes and formats.
If exist is False, make sure none of the images exist.
"""
for size, name in get_profile_image_names(user.username).items():
if exist:
self.assertTrue(storage.exists(name))
with closing(Image.open(storage.path(name))) as img:
self.assertEqual(img.size, (size, size))
self.assertEqual(img.format, 'JPEG')
else:
self.assertFalse(storage.exists(name))
def get_expected_user_profile(self, username):
"""
Returns the expected user profile data for a given username
"""
url = 'http://example-storage.com/profile-images/{filename}_{{size}}.jpg?v={timestamp}'.format(
filename=hashlib.md5('secret' + username).hexdigest(),
timestamp=self.TEST_PROFILE_IMAGE_UPLOADED_AT.strftime("%s")
)
return {
'profile': {
'image': {
'has_image': True,
'image_url_full': url.format(size=500),
'image_url_large': url.format(size=120),
'image_url_medium': url.format(size=50),
'image_url_small': url.format(size=30),
}
}
}
......@@ -26,7 +26,7 @@ from discussion_api.api import (
update_comment,
update_thread,
)
from discussion_api.forms import CommentListGetForm, ThreadListGetForm, _PaginationForm
from discussion_api.forms import CommentListGetForm, ThreadListGetForm, CommentGetForm
from openedx.core.lib.api.parsers import MergePatchParser
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin, view_auth_classes
......@@ -118,7 +118,7 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015
GET /api/discussion/v1/threads/thread_id
GET /api/discussion/v1/threads/{thread_id}
POST /api/discussion/v1/threads
{
......@@ -164,10 +164,19 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
* view: "unread" for threads the requesting user has not read, or
"unanswered" for question threads with no marked answer. Only one
can be selected.
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
The topic_id, text_search, and following parameters are mutually
exclusive (i.e. only one may be specified in a request)
**GET Thread Parameters**:
* thread_id (required): The id of the thread
* requested_fields (optional parameter): (list) Indicates which additional
fields to return for each thread. (supports 'profile_image')
**POST Parameters**:
* course_id (required): The course to create the thread in
......@@ -279,13 +288,15 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["view"],
form.cleaned_data["order_by"],
form.cleaned_data["order_direction"],
form.cleaned_data["requested_fields"]
)
def retrieve(self, request, thread_id=None):
"""
Implements the GET method for thread ID
"""
return Response(get_thread(request, thread_id))
requested_fields = request.GET.get('requested_fields')
return Response(get_thread(request, thread_id, requested_fields))
def create(self, request):
"""
......@@ -351,6 +362,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**GET Child Comment List Parameters**:
* comment_id (required): The comment to retrieve child comments for
......@@ -359,6 +373,9 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100)
* requested_fields: (list) Indicates which additional fields to return
for each thread. (supports 'profile_image')
**POST Parameters**:
......@@ -453,21 +470,23 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["thread_id"],
form.cleaned_data["endorsed"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
)
def retrieve(self, request, comment_id=None):
"""
Implements the GET method for comments against response ID
"""
form = _PaginationForm(request.GET)
form = CommentGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
return get_response_comments(
request,
comment_id,
form.cleaned_data["page"],
form.cleaned_data["page_size"]
form.cleaned_data["page_size"],
form.cleaned_data["requested_fields"],
)
def create(self, request):
......
......@@ -66,7 +66,7 @@ def learner_profile_context(request, profile_username, user_is_staff):
own_profile = (logged_in_user.username == profile_username)
account_settings_data = get_account_settings(request, profile_username)
account_settings_data = get_account_settings(request, [profile_username])[0]
preferences_data = get_user_preferences(profile_user, profile_username)
......
......@@ -1627,7 +1627,7 @@ class TestSubmitPhotosForVerification(TestCase):
"""
request = RequestFactory().get('/url')
request.user = self.user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings['name'], full_name)
def _get_post_data(self):
......
......@@ -41,7 +41,7 @@ visible_fields = _visible_fields
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
def get_account_settings(request, username=None, configuration=None, view=None):
def get_account_settings(request, usernames=None, configuration=None, view=None):
"""Returns account information for a user serialized as JSON.
Note:
......@@ -52,8 +52,8 @@ def get_account_settings(request, username=None, configuration=None, view=None):
request (Request): The request object with account information about the requesting user.
Only the user with username `username` or users with "is_staff" privileges can get full
account information. Other users will get the account fields that the user has elected to share.
username (str): Optional username for the desired account information. If not specified,
`request.user.username` is assumed.
usernames (list): Optional list of usernames for the desired account information. If not
specified, `request.user.username` is assumed.
configuration (dict): an optional configuration specifying which fields in the account
can be shared, and the default visibility settings. If not present, the setting value with
key ACCOUNT_VISIBILITY_CONFIGURATION is used.
......@@ -62,7 +62,7 @@ def get_account_settings(request, username=None, configuration=None, view=None):
"shared", only shared account information will be returned, regardless of `request.user`.
Returns:
A dict containing account fields.
A list of users account details.
Raises:
UserNotFound: no user with username `username` exists (or `request.user.username` if
......@@ -70,27 +70,27 @@ def get_account_settings(request, username=None, configuration=None, view=None):
UserAPIInternalError: the operation failed due to an unexpected error.
"""
requesting_user = request.user
usernames = usernames or [requesting_user.username]
if username is None:
username = requesting_user.username
try:
existing_user = User.objects.select_related('profile').get(username=username)
except ObjectDoesNotExist:
requested_users = User.objects.select_related('profile').filter(username__in=usernames)
if not requested_users:
raise UserNotFound()
has_full_access = requesting_user.username == username or requesting_user.is_staff
if has_full_access and view != 'shared':
admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields')
else:
admin_fields = None
return UserReadOnlySerializer(
existing_user,
configuration=configuration,
custom_fields=admin_fields,
context={'request': request}
).data
serialized_users = []
for user in requested_users:
has_full_access = requesting_user.is_staff or requesting_user.username == user.username
if has_full_access and view != 'shared':
admin_fields = settings.ACCOUNT_VISIBILITY_CONFIGURATION.get('admin_fields')
else:
admin_fields = None
serialized_users.append(UserReadOnlySerializer(
user,
configuration=configuration,
custom_fields=admin_fields,
context={'request': request}
).data)
return serialized_users
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
......
......@@ -57,13 +57,13 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_get_username_provided(self):
"""Test the difference in behavior when a username is supplied to get_account_settings."""
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.user.username)
account_settings = get_account_settings(self.default_request, usernames=[self.user.username])[0]
self.assertEqual(self.user.username, account_settings["username"])
account_settings = get_account_settings(self.default_request, username=self.different_user.username)
account_settings = get_account_settings(self.default_request, usernames=[self.different_user.username])[0]
self.assertEqual(self.different_user.username, account_settings["username"])
def test_get_configuration_provided(self):
......@@ -81,20 +81,20 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
}
# With default configuration settings, email is not shared with other (non-staff) users.
account_settings = get_account_settings(self.default_request, self.different_user.username)
account_settings = get_account_settings(self.default_request, [self.different_user.username])[0]
self.assertNotIn("email", account_settings)
account_settings = get_account_settings(
self.default_request,
self.different_user.username,
[self.different_user.username],
configuration=config,
)
)[0]
self.assertEqual(self.different_user.email, account_settings["email"])
def test_get_user_not_found(self):
"""Test that UserNotFound is thrown if there is no user with username."""
with self.assertRaises(UserNotFound):
get_account_settings(self.default_request, username="does_not_exist")
get_account_settings(self.default_request, usernames=["does_not_exist"])
self.user.username = "does_not_exist"
request = self.request_factory.get("/api/user/v1/accounts/")
......@@ -105,11 +105,11 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
def test_update_username_provided(self):
"""Test the difference in behavior when a username is supplied to update_account_settings."""
update_account_settings(self.user, {"name": "Mickey Mouse"})
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"])
update_account_settings(self.user, {"name": "Donald Duck"}, username=self.user.username)
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Donald Duck", account_settings["name"])
with self.assertRaises(UserNotAuthorized):
......@@ -189,7 +189,7 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
self.assertIn("Error thrown from do_email_change_request", context_manager.exception.developer_message)
# Verify that the name change happened, even though the attempt to send the email failed.
account_settings = get_account_settings(self.default_request)
account_settings = get_account_settings(self.default_request)[0]
self.assertEqual("Mickey Mouse", account_settings["name"])
@patch('openedx.core.djangoapps.user_api.accounts.serializers.AccountUserSerializer.save')
......@@ -255,7 +255,7 @@ class AccountSettingsOnCreationTest(TestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
# Expect a date joined field but remove it to simplify the following comparison
self.assertIsNotNone(account_settings['date_joined'])
......@@ -341,14 +341,14 @@ class AccountCreationActivationAndPasswordChangeTest(TestCase):
request = RequestFactory().get("/api/user/v1/accounts/")
request.user = user
account = get_account_settings(request)
account = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account["username"])
self.assertEqual(self.EMAIL, account["email"])
self.assertFalse(account["is_active"])
# Activate the account and verify that it is now active
activate_account(activation_key)
account = get_account_settings(request)
account = get_account_settings(request)[0]
self.assertTrue(account['is_active'])
def test_create_account_duplicate_username(self):
......
......@@ -10,6 +10,7 @@ from rest_framework import permissions
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from openedx.core.lib.api.authentication import (
SessionAuthenticationAllowInactiveUser,
......@@ -20,7 +21,7 @@ from .api import get_account_settings, update_account_settings
from ..errors import UserNotFound, UserNotAuthorized, AccountUpdateError, AccountValidationError
class AccountView(APIView):
class AccountViewSet(ViewSet):
"""
**Use Cases**
......@@ -29,6 +30,7 @@ class AccountView(APIView):
**Example Requests**
GET /api/user/v1/accounts?usernames={username1,username2}[?view=shared]
GET /api/user/v1/accounts/{username}/[?view=shared]
PATCH /api/user/v1/accounts/{username}/{"key":"value"} "application/merge-patch+json"
......@@ -146,19 +148,34 @@ class AccountView(APIView):
permission_classes = (permissions.IsAuthenticated,)
parser_classes = (MergePatchParser,)
def get(self, request, username):
def list(self, request):
"""
GET /api/user/v1/accounts/{username}/
GET /api/user/v1/accounts?username={username1,username2}
"""
usernames = request.GET.get('username')
try:
if usernames:
usernames = usernames.strip(',').split(',')
account_settings = get_account_settings(
request, username, view=request.query_params.get('view'))
request, usernames, view=request.query_params.get('view'))
except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings)
def patch(self, request, username):
def retrieve(self, request, username):
"""
GET /api/user/v1/accounts/{username}/
"""
try:
account_settings = get_account_settings(
request, [username], view=request.query_params.get('view'))
except UserNotFound:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
return Response(account_settings[0])
def partial_update(self, request, username):
"""
PATCH /api/user/v1/accounts/{username}/
......@@ -169,7 +186,7 @@ class AccountView(APIView):
try:
with transaction.atomic():
update_account_settings(request.user, request.data, username=username)
account_settings = get_account_settings(request, username)
account_settings = get_account_settings(request, [username])[0]
except UserNotAuthorized:
return Response(status=status.HTTP_403_FORBIDDEN if request.user.is_staff else status.HTTP_404_NOT_FOUND)
except UserNotFound:
......
......@@ -1366,7 +1366,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"])
......@@ -1406,7 +1406,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(account_settings["level_of_education"], self.EDUCATION)
self.assertEqual(account_settings["mailing_address"], self.ADDRESS)
......@@ -1440,7 +1440,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
user = User.objects.get(username=self.USERNAME)
request = RequestFactory().get('/url')
request.user = user
account_settings = get_account_settings(request)
account_settings = get_account_settings(request)[0]
self.assertEqual(self.USERNAME, account_settings["username"])
self.assertEqual(self.EMAIL, account_settings["email"])
......
......@@ -6,16 +6,23 @@ from django.conf import settings
from django.conf.urls import patterns, url
from ..profile_images.views import ProfileImageView
from .accounts.views import AccountView
from .accounts.views import AccountViewSet
from .preferences.views import PreferencesView, PreferencesDetailView
ACCOUNT_LIST = AccountViewSet.as_view({
'get': 'list',
})
ACCOUNT_DETAIL = AccountViewSet.as_view({
'get': 'retrieve',
'patch': 'partial_update',
})
urlpatterns = patterns(
'',
url(
r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN),
AccountView.as_view(),
name="accounts_api"
),
url(r'^v1/accounts/{}$'.format(settings.USERNAME_PATTERN), ACCOUNT_DETAIL, name='accounts_api'),
url(r'^v1/accounts$', ACCOUNT_LIST, name='accounts_detail_api'),
url(
r'^v1/accounts/{}/image$'.format(settings.USERNAME_PATTERN),
ProfileImageView.as_view(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment