Commit 107c4959 by wajeeha-khalid

Merge pull request #10008 from edx/jia/MA-1281

MA-1281 User Account: added account_privacy field in GET and PATCH endpoints
parents 034ac1c8 017477a8
...@@ -65,7 +65,8 @@ class MembershipSerializerTestCase(SerializerTestCase): ...@@ -65,7 +65,8 @@ class MembershipSerializerTestCase(SerializerTestCase):
'image_url_medium': 'http://testserver/static/default_50.png', 'image_url_medium': 'http://testserver/static/default_50.png',
'image_url_small': 'http://testserver/static/default_30.png', 'image_url_small': 'http://testserver/static/default_30.png',
'has_image': False 'has_image': False
} },
'account_privacy': None
}) })
self.assertNotIn('membership', data['team']) self.assertNotIn('membership', data['team'])
......
...@@ -129,7 +129,7 @@ class TestDashboard(SharedModuleStoreTestCase): ...@@ -129,7 +129,7 @@ class TestDashboard(SharedModuleStoreTestCase):
team.add_user(self.user) team.add_user(self.user)
# Check the query count on the dashboard again # Check the query count on the dashboard again
with self.assertNumQueries(19): with self.assertNumQueries(20):
self.client.get(self.teams_url) self.client.get(self.teams_url)
def test_bad_course_id(self): def test_bad_course_id(self):
......
...@@ -2544,12 +2544,14 @@ ACCOUNT_VISIBILITY_CONFIGURATION = { ...@@ -2544,12 +2544,14 @@ ACCOUNT_VISIBILITY_CONFIGURATION = {
'time_zone', 'time_zone',
'language_proficiencies', 'language_proficiencies',
'bio', 'bio',
'account_privacy',
], ],
# The list of account fields that are always public # The list of account fields that are always public
"public_fields": [ "public_fields": [
'username', 'username',
'profile_image', 'profile_image',
'account_privacy',
], ],
# The list of account fields that are visible only to staff and users viewing their own profiles # The list of account fields that are visible only to staff and users viewing their own profiles
...@@ -2569,6 +2571,7 @@ ACCOUNT_VISIBILITY_CONFIGURATION = { ...@@ -2569,6 +2571,7 @@ ACCOUNT_VISIBILITY_CONFIGURATION = {
"level_of_education", "level_of_education",
"mailing_address", "mailing_address",
"requires_parental_consent", "requires_parental_consent",
"account_privacy",
] ]
} }
......
...@@ -8,6 +8,8 @@ from pytz import UTC ...@@ -8,6 +8,8 @@ from pytz import UTC
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.conf import settings from django.conf import settings
from django.core.validators import validate_email, validate_slug, ValidationError from django.core.validators import validate_email, validate_slug, ValidationError
from openedx.core.djangoapps.user_api.preferences.api import update_user_preferences
from openedx.core.djangoapps.user_api.errors import PreferenceValidationError
from student.models import User, UserProfile, Registration from student.models import User, UserProfile, Registration
from student import views as student_views from student import views as student_views
...@@ -70,7 +72,7 @@ def get_account_settings(request, username=None, configuration=None, view=None): ...@@ -70,7 +72,7 @@ def get_account_settings(request, username=None, configuration=None, view=None):
username = requesting_user.username username = requesting_user.username
try: try:
existing_user = User.objects.get(username=username) existing_user = User.objects.select_related('profile').get(username=username)
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise UserNotFound() raise UserNotFound()
...@@ -185,6 +187,13 @@ def update_account_settings(requesting_user, update, username=None): ...@@ -185,6 +187,13 @@ def update_account_settings(requesting_user, update, username=None):
for serializer in user_serializer, legacy_profile_serializer: for serializer in user_serializer, legacy_profile_serializer:
serializer.save() serializer.save()
# if any exception is raised for user preference (i.e. account_privacy), the entire transaction for user account
# patch is rolled back and the data is not saved
if 'account_privacy' in update:
update_user_preferences(
requesting_user, {'account_privacy': update["account_privacy"]}, existing_user
)
if "language_proficiencies" in update: if "language_proficiencies" in update:
new_language_proficiencies = update["language_proficiencies"] new_language_proficiencies = update["language_proficiencies"]
emit_setting_changed_event( emit_setting_changed_event(
...@@ -209,6 +218,8 @@ def update_account_settings(requesting_user, update, username=None): ...@@ -209,6 +218,8 @@ def update_account_settings(requesting_user, update, username=None):
existing_user_profile.set_meta(meta) existing_user_profile.set_meta(meta)
existing_user_profile.save() existing_user_profile.save()
except PreferenceValidationError as err:
raise AccountValidationError(err.preference_errors)
except Exception as err: except Exception as err:
raise AccountUpdateError( raise AccountUpdateError(
u"Error thrown when saving account updates: '{}'".format(err.message) u"Error thrown when saving account updates: '{}'".format(err.message)
......
...@@ -91,6 +91,7 @@ class UserReadOnlySerializer(serializers.Serializer): ...@@ -91,6 +91,7 @@ class UserReadOnlySerializer(serializers.Serializer):
"level_of_education": AccountLegacyProfileSerializer.convert_empty_to_None(profile.level_of_education), "level_of_education": AccountLegacyProfileSerializer.convert_empty_to_None(profile.level_of_education),
"mailing_address": profile.mailing_address, "mailing_address": profile.mailing_address,
"requires_parental_consent": profile.requires_parental_consent(), "requires_parental_consent": profile.requires_parental_consent(),
"account_privacy": UserPreference.get_value(user, 'account_privacy'),
} }
return self._filter_fields( return self._filter_fields(
......
...@@ -150,6 +150,9 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -150,6 +150,9 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
{"language_proficiencies": [{}]} {"language_proficiencies": [{}]}
) )
with self.assertRaises(AccountValidationError):
update_account_settings(self.user, {"account_privacy": ""})
def test_update_multiple_validation_errors(self): def test_update_multiple_validation_errors(self):
"""Test that all validation errors are built up and returned at once""" """Test that all validation errors are built up and returned at once"""
# Send a read-only error, serializer error, and email validation error. # Send a read-only error, serializer error, and email validation error.
...@@ -275,6 +278,7 @@ class AccountSettingsOnCreationTest(TestCase): ...@@ -275,6 +278,7 @@ class AccountSettingsOnCreationTest(TestCase):
}, },
'requires_parental_consent': True, 'requires_parental_consent': True,
'language_proficiencies': [], 'language_proficiencies': [],
'account_privacy': None
}) })
......
...@@ -16,6 +16,7 @@ from django.core.urlresolvers import reverse ...@@ -16,6 +16,7 @@ from django.core.urlresolvers import reverse
from django.test.testcases import TransactionTestCase from django.test.testcases import TransactionTestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from rest_framework.test import APITestCase, APIClient from rest_framework.test import APITestCase, APIClient
from openedx.core.djangoapps.user_api.models import UserPreference
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from student.models import UserProfile, LanguageProficiency, PendingEmailChange from student.models import UserProfile, LanguageProficiency, PendingEmailChange
...@@ -151,34 +152,36 @@ class TestAccountAPI(UserAPITestCase): ...@@ -151,34 +152,36 @@ class TestAccountAPI(UserAPITestCase):
} }
) )
def _verify_full_shareable_account_response(self, response): def _verify_full_shareable_account_response(self, response, account_privacy=None):
""" """
Verify that the shareable fields from the account are returned Verify that the shareable fields from the account are returned
""" """
data = response.data data = response.data
self.assertEqual(6, len(data)) self.assertEqual(7, len(data))
self.assertEqual(self.user.username, data["username"]) self.assertEqual(self.user.username, data["username"])
self.assertEqual("US", data["country"]) self.assertEqual("US", data["country"])
self._verify_profile_image_data(data, True) self._verify_profile_image_data(data, True)
self.assertIsNone(data["time_zone"]) self.assertIsNone(data["time_zone"])
self.assertEqual([{"code": "en"}], data["language_proficiencies"]) self.assertEqual([{"code": "en"}], data["language_proficiencies"])
self.assertEqual("Tired mother of twins", data["bio"]) self.assertEqual("Tired mother of twins", data["bio"])
self.assertEqual(account_privacy, data["account_privacy"])
def _verify_private_account_response(self, response, requires_parental_consent=False): def _verify_private_account_response(self, response, requires_parental_consent=False, account_privacy=None):
""" """
Verify that only the public fields are returned if a user does not want to share account fields Verify that only the public fields are returned if a user does not want to share account fields
""" """
data = response.data data = response.data
self.assertEqual(2, len(data)) self.assertEqual(3, len(data))
self.assertEqual(self.user.username, data["username"]) self.assertEqual(self.user.username, data["username"])
self._verify_profile_image_data(data, not requires_parental_consent) self._verify_profile_image_data(data, not requires_parental_consent)
self.assertEqual(account_privacy, data["account_privacy"])
def _verify_full_account_response(self, response, requires_parental_consent=False): def _verify_full_account_response(self, response, requires_parental_consent=False):
""" """
Verify that all account fields are returned (even those that are not shareable). Verify that all account fields are returned (even those that are not shareable).
""" """
data = response.data data = response.data
self.assertEqual(15, len(data)) self.assertEqual(16, len(data))
self.assertEqual(self.user.username, data["username"]) self.assertEqual(self.user.username, data["username"])
self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"]) self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"])
self.assertEqual("US", data["country"]) self.assertEqual("US", data["country"])
...@@ -194,6 +197,7 @@ class TestAccountAPI(UserAPITestCase): ...@@ -194,6 +197,7 @@ class TestAccountAPI(UserAPITestCase):
self._verify_profile_image_data(data, not requires_parental_consent) self._verify_profile_image_data(data, not requires_parental_consent)
self.assertEquals(requires_parental_consent, data["requires_parental_consent"]) self.assertEquals(requires_parental_consent, data["requires_parental_consent"])
self.assertEqual([{"code": "en"}], data["language_proficiencies"]) self.assertEqual([{"code": "en"}], data["language_proficiencies"])
self.assertEqual(UserPreference.get_value(self.user, 'account_privacy'), data["account_privacy"])
def test_anonymous_access(self): def test_anonymous_access(self):
""" """
...@@ -235,7 +239,8 @@ class TestAccountAPI(UserAPITestCase): ...@@ -235,7 +239,8 @@ class TestAccountAPI(UserAPITestCase):
""" """
self.different_client.login(username=self.different_user.username, password=self.test_password) self.different_client.login(username=self.different_user.username, password=self.test_password)
self.create_mock_profile(self.user) self.create_mock_profile(self.user)
response = self.send_get(self.different_client) with self.assertNumQueries(9):
response = self.send_get(self.different_client)
self._verify_full_shareable_account_response(response) self._verify_full_shareable_account_response(response)
# Note: using getattr so that the patching works even if there is no configuration. # Note: using getattr so that the patching works even if there is no configuration.
...@@ -249,7 +254,8 @@ class TestAccountAPI(UserAPITestCase): ...@@ -249,7 +254,8 @@ class TestAccountAPI(UserAPITestCase):
""" """
self.different_client.login(username=self.different_user.username, password=self.test_password) self.different_client.login(username=self.different_user.username, password=self.test_password)
self.create_mock_profile(self.user) self.create_mock_profile(self.user)
response = self.send_get(self.different_client) with self.assertNumQueries(9):
response = self.send_get(self.different_client)
self._verify_private_account_response(response) self._verify_private_account_response(response)
@ddt.data( @ddt.data(
...@@ -270,9 +276,9 @@ class TestAccountAPI(UserAPITestCase): ...@@ -270,9 +276,9 @@ class TestAccountAPI(UserAPITestCase):
Confirms that private fields are private, and public/shareable fields are public/shareable Confirms that private fields are private, and public/shareable fields are public/shareable
""" """
if preference_visibility == PRIVATE_VISIBILITY: if preference_visibility == PRIVATE_VISIBILITY:
self._verify_private_account_response(response) self._verify_private_account_response(response, account_privacy=PRIVATE_VISIBILITY)
else: else:
self._verify_full_shareable_account_response(response) self._verify_full_shareable_account_response(response, ALL_USERS_VISIBILITY)
client = self.login_client(api_client, requesting_username) client = self.login_client(api_client, requesting_username)
...@@ -299,9 +305,10 @@ class TestAccountAPI(UserAPITestCase): ...@@ -299,9 +305,10 @@ class TestAccountAPI(UserAPITestCase):
""" """
Internal helper to perform the actual assertions Internal helper to perform the actual assertions
""" """
response = self.send_get(self.client) with self.assertNumQueries(8):
response = self.send_get(self.client)
data = response.data data = response.data
self.assertEqual(15, len(data)) self.assertEqual(16, len(data))
self.assertEqual(self.user.username, data["username"]) self.assertEqual(self.user.username, data["username"])
self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"]) self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"])
for empty_field in ("year_of_birth", "level_of_education", "mailing_address", "bio"): for empty_field in ("year_of_birth", "level_of_education", "mailing_address", "bio"):
...@@ -315,6 +322,7 @@ class TestAccountAPI(UserAPITestCase): ...@@ -315,6 +322,7 @@ class TestAccountAPI(UserAPITestCase):
self._verify_profile_image_data(data, False) self._verify_profile_image_data(data, False)
self.assertTrue(data["requires_parental_consent"]) self.assertTrue(data["requires_parental_consent"])
self.assertEqual([], data["language_proficiencies"]) self.assertEqual([], data["language_proficiencies"])
self.assertEqual(None, data["account_privacy"])
self.client.login(username=self.user.username, password=self.test_password) self.client.login(username=self.user.username, password=self.test_password)
verify_get_own_information() verify_get_own_information()
...@@ -336,7 +344,8 @@ class TestAccountAPI(UserAPITestCase): ...@@ -336,7 +344,8 @@ class TestAccountAPI(UserAPITestCase):
legacy_profile.save() legacy_profile.save()
self.client.login(username=self.user.username, password=self.test_password) self.client.login(username=self.user.username, password=self.test_password)
response = self.send_get(self.client) with self.assertNumQueries(8):
response = self.send_get(self.client)
for empty_field in ("level_of_education", "gender", "country", "bio"): for empty_field in ("level_of_education", "gender", "country", "bio"):
self.assertIsNone(response.data[empty_field]) self.assertIsNone(response.data[empty_field])
...@@ -383,6 +392,8 @@ class TestAccountAPI(UserAPITestCase): ...@@ -383,6 +392,8 @@ class TestAccountAPI(UserAPITestCase):
"bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>", "bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>",
"z" * 3001, u"Ensure this field has no more than 3000 characters." "z" * 3001, u"Ensure this field has no more than 3000 characters."
), ),
("account_privacy", ALL_USERS_VISIBILITY),
("account_privacy", PRIVATE_VISIBILITY),
# Note that email is tested below, as it is not immediately updated. # Note that email is tested below, as it is not immediately updated.
# Note that language_proficiencies is tested below as there are multiple error and success conditions. # Note that language_proficiencies is tested below as there are multiple error and success conditions.
) )
...@@ -407,8 +418,9 @@ class TestAccountAPI(UserAPITestCase): ...@@ -407,8 +418,9 @@ class TestAccountAPI(UserAPITestCase):
), ),
error_response.data["field_errors"][field]["developer_message"] error_response.data["field_errors"][field]["developer_message"]
) )
else: elif field != "account_privacy":
# If there are no values that would fail validation, then empty string should be supported. # If there are no values that would fail validation, then empty string should be supported;
# except for account_privacy, which cannot be an empty string.
response = self.send_patch(client, {field: ""}) response = self.send_patch(client, {field: ""})
self.assertEqual("", response.data[field]) self.assertEqual("", response.data[field])
...@@ -662,7 +674,7 @@ class TestAccountAPI(UserAPITestCase): ...@@ -662,7 +674,7 @@ class TestAccountAPI(UserAPITestCase):
response = self.send_get(client) response = self.send_get(client)
if has_full_access: if has_full_access:
data = response.data data = response.data
self.assertEqual(15, len(data)) self.assertEqual(16, len(data))
self.assertEqual(self.user.username, data["username"]) self.assertEqual(self.user.username, data["username"])
self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"]) self.assertEqual(self.user.first_name + " " + self.user.last_name, data["name"])
self.assertEqual(self.user.email, data["email"]) self.assertEqual(self.user.email, data["email"])
...@@ -675,12 +687,17 @@ class TestAccountAPI(UserAPITestCase): ...@@ -675,12 +687,17 @@ class TestAccountAPI(UserAPITestCase):
self.assertIsNotNone(data["date_joined"]) self.assertIsNotNone(data["date_joined"])
self._verify_profile_image_data(data, False) self._verify_profile_image_data(data, False)
self.assertTrue(data["requires_parental_consent"]) self.assertTrue(data["requires_parental_consent"])
self.assertEqual(ALL_USERS_VISIBILITY, data["account_privacy"])
else: else:
self._verify_private_account_response(response, requires_parental_consent=True) self._verify_private_account_response(
response, requires_parental_consent=True, account_privacy=ALL_USERS_VISIBILITY
)
# Verify that the shared view is still private # Verify that the shared view is still private
response = self.send_get(client, query_parameters='view=shared') response = self.send_get(client, query_parameters='view=shared')
self._verify_private_account_response(response, requires_parental_consent=True) self._verify_private_account_response(
response, requires_parental_consent=True, account_privacy=ALL_USERS_VISIBILITY
)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
......
...@@ -96,6 +96,8 @@ class AccountView(APIView): ...@@ -96,6 +96,8 @@ class AccountView(APIView):
requiring parental consent. requiring parental consent.
* username: The username associated with the account. * username: The username associated with the account.
* year_of_birth: The year the user was born, as an integer, or null. * year_of_birth: The year the user was born, as an integer, or null.
* account_privacy: The user's setting for sharing her personal
profile. Possible values are "all_users" or "private".
For all text fields, plain text instead of HTML is supported. The For all text fields, plain text instead of HTML is supported. The
data is stored exactly as specified. Clients must HTML escape data is stored exactly as specified. Clients must HTML escape
......
...@@ -45,7 +45,7 @@ def get_user_preference(requesting_user, preference_key, username=None): ...@@ -45,7 +45,7 @@ def get_user_preference(requesting_user, preference_key, username=None):
UserNotAuthorized: the requesting_user does not have access to the user preference. UserNotAuthorized: the requesting_user does not have access to the user preference.
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username, allow_staff=True) existing_user = _get_authorized_user(requesting_user, username, allow_staff=True)
return UserPreference.get_value(existing_user, preference_key) return UserPreference.get_value(existing_user, preference_key)
...@@ -68,7 +68,7 @@ def get_user_preferences(requesting_user, username=None): ...@@ -68,7 +68,7 @@ def get_user_preferences(requesting_user, username=None):
UserNotAuthorized: the requesting_user does not have access to the user preference. UserNotAuthorized: the requesting_user does not have access to the user preference.
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username, allow_staff=True) existing_user = _get_authorized_user(requesting_user, username, allow_staff=True)
# Django Rest Framework V3 uses the current request to version # Django Rest Framework V3 uses the current request to version
# hyperlinked URLS, so we need to retrieve the request and pass # hyperlinked URLS, so we need to retrieve the request and pass
...@@ -84,8 +84,8 @@ def get_user_preferences(requesting_user, username=None): ...@@ -84,8 +84,8 @@ def get_user_preferences(requesting_user, username=None):
@intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError]) @intercept_errors(UserAPIInternalError, ignore_errors=[UserAPIRequestError])
def update_user_preferences(requesting_user, update, username=None): def update_user_preferences(requesting_user, update, user=None):
"""Update the user preferences for the given username. """Update the user preferences for the given user.
Note: Note:
It is up to the caller of this method to enforce the contract that this method is only called It is up to the caller of this method to enforce the contract that this method is only called
...@@ -98,8 +98,8 @@ def update_user_preferences(requesting_user, update, username=None): ...@@ -98,8 +98,8 @@ def update_user_preferences(requesting_user, update, username=None):
Some notes: Some notes:
Values are expected to be strings. Non-string values will be converted to strings. Values are expected to be strings. Non-string values will be converted to strings.
Null values for a preference will be treated as a request to delete the key in question. Null values for a preference will be treated as a request to delete the key in question.
username (str): Optional username specifying which account should be updated. If not specified, user (str/User): Optional, either username string or user object specifying which account should be updated.
`requesting_user.username` is assumed. If not specified, `requesting_user.username` is assumed.
Raises: Raises:
UserNotFound: no user with username `username` exists (or `requesting_user.username` if UserNotFound: no user with username `username` exists (or `requesting_user.username` if
...@@ -110,7 +110,10 @@ def update_user_preferences(requesting_user, update, username=None): ...@@ -110,7 +110,10 @@ def update_user_preferences(requesting_user, update, username=None):
PreferenceUpdateError: the operation failed when performing the update. PreferenceUpdateError: the operation failed when performing the update.
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username) if not user or isinstance(user, basestring):
user = _get_authorized_user(requesting_user, user)
else:
_check_authorized(requesting_user, user.username)
# First validate each preference setting # First validate each preference setting
errors = {} errors = {}
...@@ -119,7 +122,7 @@ def update_user_preferences(requesting_user, update, username=None): ...@@ -119,7 +122,7 @@ def update_user_preferences(requesting_user, update, username=None):
preference_value = update[preference_key] preference_value = update[preference_key]
if preference_value is not None: if preference_value is not None:
try: try:
serializer = create_user_preference_serializer(existing_user, preference_key, preference_value) serializer = create_user_preference_serializer(user, preference_key, preference_value)
validate_user_preference_serializer(serializer, preference_key, preference_value) validate_user_preference_serializer(serializer, preference_key, preference_value)
serializers[preference_key] = serializer serializers[preference_key] = serializer
except PreferenceValidationError as error: except PreferenceValidationError as error:
...@@ -168,7 +171,7 @@ def set_user_preference(requesting_user, preference_key, preference_value, usern ...@@ -168,7 +171,7 @@ def set_user_preference(requesting_user, preference_key, preference_value, usern
PreferenceUpdateError: the operation failed when performing the update. PreferenceUpdateError: the operation failed when performing the update.
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username) existing_user = _get_authorized_user(requesting_user, username)
serializer = create_user_preference_serializer(existing_user, preference_key, preference_value) serializer = create_user_preference_serializer(existing_user, preference_key, preference_value)
validate_user_preference_serializer(serializer, preference_key, preference_value) validate_user_preference_serializer(serializer, preference_key, preference_value)
try: try:
...@@ -203,7 +206,7 @@ def delete_user_preference(requesting_user, preference_key, username=None): ...@@ -203,7 +206,7 @@ def delete_user_preference(requesting_user, preference_key, username=None):
PreferenceUpdateError: the operation failed when performing the update. PreferenceUpdateError: the operation failed when performing the update.
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username) existing_user = _get_authorized_user(requesting_user, username)
try: try:
user_preference = UserPreference.objects.get(user=existing_user, key=preference_key) user_preference = UserPreference.objects.get(user=existing_user, key=preference_key)
except ObjectDoesNotExist: except ObjectDoesNotExist:
...@@ -298,25 +301,32 @@ def _track_update_email_opt_in(user_id, organization, opt_in): ...@@ -298,25 +301,32 @@ def _track_update_email_opt_in(user_id, organization, opt_in):
) )
def _get_user(requesting_user, username=None, allow_staff=False): def _get_authorized_user(requesting_user, username=None, allow_staff=False):
""" """
Helper method to return the user for a given username. Helper method to return the authorized user for a given username.
If username is not provided, requesting_user.username is assumed. If username is not provided, requesting_user.username is assumed.
""" """
if username is None: if username is None:
username = requesting_user.username username = requesting_user.username
try: try:
existing_user = User.objects.get(username=username) existing_user = User.objects.get(username=username)
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise UserNotFound() raise UserNotFound()
_check_authorized(requesting_user, username, allow_staff)
return existing_user
def _check_authorized(requesting_user, username, allow_staff=False):
"""
Helper method that raises UserNotAuthorized if requesting user
is not owner user or is not staff if access to staff is given
(i.e. 'allow_staff' = true)
"""
if requesting_user.username != username: if requesting_user.username != username:
if not requesting_user.is_staff or not allow_staff: if not requesting_user.is_staff or not allow_staff:
raise UserNotAuthorized() raise UserNotAuthorized()
return existing_user
def create_user_preference_serializer(user, preference_key, preference_value): def create_user_preference_serializer(user, preference_key, preference_value):
"""Creates a serializer for the specified user preference. """Creates a serializer for the specified user preference.
......
...@@ -181,6 +181,34 @@ class TestPreferenceAPI(TestCase): ...@@ -181,6 +181,34 @@ class TestPreferenceAPI(TestCase):
"new_value" "new_value"
) )
def test_update_user_preferences_with_username(self):
"""
Verifies the basic behavior of update_user_preferences when passed
username string.
"""
update_data = {
self.test_preference_key: "new_value"
}
update_user_preferences(self.user, update_data, user=self.user.username)
self.assertEqual(
get_user_preference(self.user, self.test_preference_key),
"new_value"
)
def test_update_user_preferences_with_user(self):
"""
Verifies the basic behavior of update_user_preferences when passed
user object.
"""
update_data = {
self.test_preference_key: "new_value"
}
update_user_preferences(self.user, update_data, user=self.user)
self.assertEqual(
get_user_preference(self.user, self.test_preference_key),
"new_value"
)
@patch('openedx.core.djangoapps.user_api.models.UserPreference.delete') @patch('openedx.core.djangoapps.user_api.models.UserPreference.delete')
@patch('openedx.core.djangoapps.user_api.models.UserPreference.save') @patch('openedx.core.djangoapps.user_api.models.UserPreference.save')
def test_update_user_preferences_errors(self, user_preference_save, user_preference_delete): def test_update_user_preferences_errors(self, user_preference_save, user_preference_delete):
...@@ -191,16 +219,16 @@ class TestPreferenceAPI(TestCase): ...@@ -191,16 +219,16 @@ class TestPreferenceAPI(TestCase):
self.test_preference_key: "new_value" self.test_preference_key: "new_value"
} }
with self.assertRaises(UserNotFound): with self.assertRaises(UserNotFound):
update_user_preferences(self.user, update_data, username="no_such_user") update_user_preferences(self.user, update_data, user="no_such_user")
with self.assertRaises(UserNotFound): with self.assertRaises(UserNotFound):
update_user_preferences(self.no_such_user, update_data) update_user_preferences(self.no_such_user, update_data)
with self.assertRaises(UserNotAuthorized): with self.assertRaises(UserNotAuthorized):
update_user_preferences(self.staff_user, update_data, username=self.user.username) update_user_preferences(self.staff_user, update_data, user=self.user.username)
with self.assertRaises(UserNotAuthorized): with self.assertRaises(UserNotAuthorized):
update_user_preferences(self.different_user, update_data, username=self.user.username) update_user_preferences(self.different_user, update_data, user=self.user.username)
too_long_key = "x" * 256 too_long_key = "x" * 256
with self.assertRaises(PreferenceValidationError) as context_manager: with self.assertRaises(PreferenceValidationError) as context_manager:
......
...@@ -118,7 +118,7 @@ class PreferencesView(APIView): ...@@ -118,7 +118,7 @@ class PreferencesView(APIView):
) )
try: try:
with transaction.commit_on_success(): with transaction.commit_on_success():
update_user_preferences(request.user, request.data, username=username) update_user_preferences(request.user, request.data, user=username)
except UserNotAuthorized: except UserNotAuthorized:
return Response(status=status.HTTP_403_FORBIDDEN) return Response(status=status.HTTP_403_FORBIDDEN)
except UserNotFound: except UserNotFound:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment