Commit 450c0e34 by cahrens

Enforce content_type of 'application/merge-patch+json' for merge patch.

parent 09c607c6
import unittest import unittest
import ddt import ddt
import json
from django.test import TestCase from django.test import TestCase
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
...@@ -20,8 +21,8 @@ class TestAccountAPI(APITestCase): ...@@ -20,8 +21,8 @@ class TestAccountAPI(APITestCase):
self.anonymous_client = APIClient() self.anonymous_client = APIClient()
self.bad_user = UserFactory.create(password=TEST_PASSWORD) self.different_user = UserFactory.create(password=TEST_PASSWORD)
self.bad_client = APIClient() self.different_client = APIClient()
self.staff_user = UserFactory(is_staff=True, password=TEST_PASSWORD) self.staff_user = UserFactory(is_staff=True, password=TEST_PASSWORD)
self.staff_client = APIClient() self.staff_client = APIClient()
...@@ -39,12 +40,18 @@ class TestAccountAPI(APITestCase): ...@@ -39,12 +40,18 @@ class TestAccountAPI(APITestCase):
self.accounts_base_uri = reverse("accounts_api", kwargs={'username': self.user.username}) self.accounts_base_uri = reverse("accounts_api", kwargs={'username': self.user.username})
def test_get_account_anonymous_user(self): def test_get_account_anonymous_user(self):
"""
Test that an anonymous client (not logged in) cannot call get.
"""
response = self.anonymous_client.get(self.accounts_base_uri) response = self.anonymous_client.get(self.accounts_base_uri)
self.assert_status_code(401, response) self.assert_status_code(401, response)
def test_get_account_bad_user(self): def test_get_account_different_user(self):
self.bad_client.login(username=self.bad_user.username, password=TEST_PASSWORD) """
response = self.bad_client.get(self.accounts_base_uri) Test that a client (logged in) cannot get the account information for a different client.
"""
self.different_client.login(username=self.different_user.username, password=TEST_PASSWORD)
response = self.different_client.get(self.accounts_base_uri)
self.assert_status_code(404, response) self.assert_status_code(404, response)
@ddt.data( @ddt.data(
...@@ -53,6 +60,10 @@ class TestAccountAPI(APITestCase): ...@@ -53,6 +60,10 @@ class TestAccountAPI(APITestCase):
) )
@ddt.unpack @ddt.unpack
def test_get_account(self, api_client, user): def test_get_account(self, api_client, user):
"""
Test that a client (logged in) can get her own account information. Also verifies that a "is_staff"
user can get the account information for other users.
"""
client = self.login_client(api_client, user) client = self.login_client(api_client, user)
response = client.get(self.accounts_base_uri) response = client.get(self.accounts_base_uri)
...@@ -98,17 +109,18 @@ class TestAccountAPI(APITestCase): ...@@ -98,17 +109,18 @@ class TestAccountAPI(APITestCase):
def test_patch_account( def test_patch_account(
self, api_client, user, field, value, fails_validation_value=None, developer_validation_message=None self, api_client, user, field, value, fails_validation_value=None, developer_validation_message=None
): ):
"""
Test the behavior of patch, when using the correct content_type.
"""
client = self.login_client(api_client, user) client = self.login_client(api_client, user)
patch_response = client.patch(self.accounts_base_uri, data={field: value}) self.send_patch(client, {field: value})
self.assert_status_code(204, patch_response)
get_response = client.get(self.accounts_base_uri) get_response = client.get(self.accounts_base_uri)
self.assert_status_code(200, get_response) self.assert_status_code(200, get_response)
self.assertEqual(value, get_response.data[field]) self.assertEqual(value, get_response.data[field])
if fails_validation_value: if fails_validation_value:
error_response = client.patch(self.accounts_base_uri, data={field: fails_validation_value}) error_response = self.send_patch(client, {field: fails_validation_value}, expected_status=400)
self.assert_status_code(400, error_response)
self.assertEqual( self.assertEqual(
"Value '{0}' is not valid for field '{1}'.".format(fails_validation_value, field), "Value '{0}' is not valid for field '{1}'.".format(fails_validation_value, field),
error_response.data["field_errors"][field]["user_message"] error_response.data["field_errors"][field]["user_message"]
...@@ -119,8 +131,7 @@ class TestAccountAPI(APITestCase): ...@@ -119,8 +131,7 @@ class TestAccountAPI(APITestCase):
) )
else: else:
# If there are no values that would fail validation, then empty string should be supported. # If there are no values that would fail validation, then empty string should be supported.
patch_response = client.patch(self.accounts_base_uri, data={field: ""}) self.send_patch(client, {field: ""})
self.assert_status_code(204, patch_response)
get_response = client.get(self.accounts_base_uri) get_response = client.get(self.accounts_base_uri)
self.assert_status_code(200, get_response) self.assert_status_code(200, get_response)
...@@ -132,6 +143,9 @@ class TestAccountAPI(APITestCase): ...@@ -132,6 +143,9 @@ class TestAccountAPI(APITestCase):
) )
@ddt.unpack @ddt.unpack
def test_patch_account_noneditable(self, api_client, user): def test_patch_account_noneditable(self, api_client, user):
"""
Tests the behavior of patch when a read-only field is attempted to be edited.
"""
client = self.login_client(api_client, user) client = self.login_client(api_client, user)
def verify_error_response(field_name, data): def verify_error_response(field_name, data):
...@@ -143,8 +157,7 @@ class TestAccountAPI(APITestCase): ...@@ -143,8 +157,7 @@ class TestAccountAPI(APITestCase):
) )
for field_name in ["username", "email", "date_joined", "name"]: for field_name in ["username", "email", "date_joined", "name"]:
response = client.patch(self.accounts_base_uri, data={field_name: "will_error", "gender": "f"}) response = self.send_patch(client, {field_name: "will_error", "gender": "f"}, expected_status=400)
self.assert_status_code(400, response)
verify_error_response(field_name, response.data) verify_error_response(field_name, response.data)
# Make sure that gender did not change. # Make sure that gender did not change.
...@@ -152,18 +165,35 @@ class TestAccountAPI(APITestCase): ...@@ -152,18 +165,35 @@ class TestAccountAPI(APITestCase):
self.assertEqual("m", response.data["gender"]) self.assertEqual("m", response.data["gender"])
# Test error message with multiple read-only items # Test error message with multiple read-only items
response = client.patch(self.accounts_base_uri, data={"username": "will_error", "email": "xx"}) response = self.send_patch(client, {"username": "will_error", "email": "xx"}, expected_status=400)
self.assert_status_code(400, response)
self.assertEqual(2, len(response.data["field_errors"])) self.assertEqual(2, len(response.data["field_errors"]))
verify_error_response("username", response.data) verify_error_response("username", response.data)
verify_error_response("email", response.data) verify_error_response("email", response.data)
def test_patch_bad_content_type(self):
"""
Test the behavior of patch when an incorrect content_type is specified.
"""
self.client.login(username=self.user.username, password=TEST_PASSWORD)
self.send_patch(self.client, {}, content_type="application/json", expected_status=415)
self.send_patch(self.client, {}, content_type="application/xml", expected_status=415)
def assert_status_code(self, expected_status_code, response): def assert_status_code(self, expected_status_code, response):
"""Assert that the given response has the expected status code""" """Assert that the given response has the expected status code"""
self.assertEqual(expected_status_code, response.status_code) self.assertEqual(expected_status_code, response.status_code)
def login_client(self, api_client, user): def login_client(self, api_client, user):
"""Helper method for getting the client and user and logging in. Returns client. """
client = getattr(self, api_client) client = getattr(self, api_client)
user = getattr(self, user) user = getattr(self, user)
client.login(username=user.username, password=TEST_PASSWORD) client.login(username=user.username, password=TEST_PASSWORD)
return client return client
def send_patch(self, client, json_data, content_type="application/merge-patch+json", expected_status=204):
"""
Helper method for sending a patch to the server, defaulting to application/merge-patch+json content_type.
Verifies the expected status and returns the response.
"""
response = client.patch(self.accounts_base_uri, data=json.dumps(json_data), content_type=content_type)
self.assert_status_code(expected_status, response)
return response
...@@ -7,23 +7,25 @@ from rest_framework.response import Response ...@@ -7,23 +7,25 @@ from rest_framework.response import Response
from rest_framework import status from rest_framework import status
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework import permissions from rest_framework import permissions
from rest_framework import parsers
from student.models import UserProfile from student.models import UserProfile
from openedx.core.djangoapps.user_api.accounts.serializers import AccountLegacyProfileSerializer, AccountUserSerializer from openedx.core.djangoapps.user_api.accounts.serializers import AccountLegacyProfileSerializer, AccountUserSerializer
from openedx.core.lib.api.permissions import IsUserInUrlOrStaff from openedx.core.lib.api.permissions import IsUserInUrlOrStaff
from openedx.core.lib.api.parsers import MergePatchParser
class AccountView(APIView): class AccountView(APIView):
""" """
**Use Cases** **Use Cases**
Get or update the user's account information. Get or update the user's account information. Updates are only supported through merge patch.
**Example Requests**: **Example Requests**:
GET /api/user/v0/accounts/{username}/ GET /api/user/v0/accounts/{username}/
PATCH /api/user/v0/accounts/{username}/ PATCH /api/user/v0/accounts/{username}/ with content_type "application/merge-patch+json"
**Response Values for GET** **Response Values for GET**
...@@ -64,10 +66,12 @@ class AccountView(APIView): ...@@ -64,10 +66,12 @@ class AccountView(APIView):
**Response for PATCH** **Response for PATCH**
Returns a 204 status if successful, with no additional content. Returns a 204 status if successful, with no additional content.
If "application/merge-patch+json" is not the specified content_type, returns a 415 status.
""" """
authentication_classes = (OAuth2Authentication, SessionAuthentication) authentication_classes = (OAuth2Authentication, SessionAuthentication)
permission_classes = (permissions.IsAuthenticated, IsUserInUrlOrStaff) permission_classes = (permissions.IsAuthenticated, IsUserInUrlOrStaff)
parser_classes = (MergePatchParser,)
def get(self, request, username): def get(self, request, username):
""" """
...@@ -82,6 +86,10 @@ class AccountView(APIView): ...@@ -82,6 +86,10 @@ class AccountView(APIView):
def patch(self, request, username): def patch(self, request, username):
""" """
PATCH /api/user/v0/accounts/{username}/ PATCH /api/user/v0/accounts/{username}/
Note that this implementation is the "merge patch" implementation proposed in
https://tools.ietf.org/html/rfc7396. The content_type must be "application/merge-patch+json" or
else an error response with status code 415 will be returned.
""" """
existing_user, existing_user_profile = self._get_user_and_profile(username) existing_user, existing_user_profile = self._get_user_and_profile(username)
......
from rest_framework import parsers
class MergePatchParser(parsers.JSONParser):
"""
Custom parser to be used with the "merge patch" implementation (https://tools.ietf.org/html/rfc7396).
"""
media_type = 'application/merge-patch+json'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment