Commit 22471dc5 by Kyle McCormick

Change course_summaries() to use POST instead of GET

POST method allows large number of course ID arguments to be passed as
data, while GET method is restricted by URL length.

EDUCATOR-464
parent e64dc8ca
......@@ -6,3 +6,4 @@ Dylan Rhodes <dylanr@stanford.edu>
Dmitry Viskov <dmitry.viskov@webenterprise.ru>
Tyler Hallada <thallada@edx.org>
Braden MacDonald <braden@opencraft.com>
Kyle McCormick <kylemccor@gmail.com>
......@@ -28,6 +28,9 @@ class Client(object):
DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S'
METHOD_GET = 'GET'
METHOD_POST = 'POST'
def __init__(self, base_url, auth_token=None, timeout=0.25):
"""
Initialize the client.
......@@ -63,17 +66,37 @@ class Client(object):
Raises: ClientError if the resource cannot be retrieved for any reason.
"""
response = self._request(resource, timeout=timeout, data_format=data_format)
return self._get_or_post(
self.METHOD_GET,
resource,
timeout=timeout,
data_format=data_format
)
def post(self, resource, post_data=None, timeout=None, data_format=DF.JSON):
"""
Retrieve the data for POST request.
if data_format == DF.CSV:
return response.text
Arguments:
try:
return response.json()
except ValueError:
message = 'Unable to decode JSON response'
log.exception(message)
raise ClientError(message)
resource (str): Path in the form of slash separated strings.
post_data (dict): Dictionary containing POST data.
timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and
raising an error.
data_format (str): Format in which data should be returned
Returns: API response data in specified data_format
Raises: ClientError if the resource cannot be retrieved for any reason.
"""
return self._get_or_post(
self.METHOD_POST,
resource,
post_data=post_data,
timeout=timeout,
data_format=data_format
)
def has_resource(self, resource, timeout=None):
"""
......@@ -91,13 +114,32 @@ class Client(object):
"""
try:
self._request(resource, timeout=timeout)
self._request(self.METHOD_GET, resource, timeout=timeout)
return True
except ClientError:
return False
def _get_or_post(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON):
response = self._request(
method,
resource,
post_data=post_data,
timeout=timeout,
data_format=data_format
)
if data_format == DF.CSV:
return response.text
try:
return response.json()
except ValueError:
message = 'Unable to decode JSON response'
log.exception(message)
raise ClientError(message)
# pylint: disable=no-member
def _request(self, resource, timeout=None, data_format=DF.JSON):
def _request(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON):
if timeout is None:
timeout = self.timeout
......@@ -114,7 +156,17 @@ class Client(object):
try:
uri = '{0}/{1}'.format(self.base_url, resource)
response = requests.get(uri, headers=headers, timeout=timeout)
if method == self.METHOD_GET:
response = requests.get(uri, headers=headers, timeout=timeout)
elif method == self.METHOD_POST:
response = requests.post(uri, data=(post_data or {}), headers=headers, timeout=timeout)
else:
raise ValueError(
'Invalid \'method\' argument: expected {0} or {1}, got {2}'.format(
self.METHOD_GET, self.METHOD_POST, method
)
)
status = response.status_code
if status != requests.codes.ok:
......
import urllib
import analyticsclient.constants.data_format as DF
......@@ -27,15 +25,12 @@ class CourseSummaries(object):
exclude: Array of fields to exclude from response. Default is to not exclude any fields.
programs: If included in the query parameters, will include the programs array in the response.
"""
query_params = {}
for query_arg, data in zip(['course_ids', 'fields', 'exclude', 'programs'],
[course_ids, fields, exclude, programs]):
post_data = {}
for param_name, data in zip(['course_ids', 'fields', 'exclude', 'programs'],
[course_ids, fields, exclude, programs]):
if data:
query_params[query_arg] = ','.join(data)
post_data[param_name] = data
path = 'course_summaries/'
querystring = urllib.urlencode(query_params)
if querystring:
path += '?{0}'.format(querystring)
return self.client.get(path, data_format=data_format)
return self.client.post(path, post_data=post_data, data_format=data_format)
......@@ -34,6 +34,7 @@ class APIListTestCase(object):
# Override in the subclass:
endpoint = 'list'
id_field = 'id'
uses_post_method = False
def setUp(self):
"""Set up the test case."""
......@@ -58,17 +59,25 @@ class APIListTestCase(object):
def kwarg_test(self, **kwargs):
"""Construct URL with given query parameters and check if it is what we expect."""
httpretty.reset()
uri_template = '{uri}?'
for key in kwargs:
uri_template += '%s={%s}' % (key, key)
uri = uri_template.format(uri=self.base_uri, **kwargs)
httpretty.register_uri(httpretty.GET, uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs)
self.verify_last_querystring_equal(self.expected_query(**kwargs))
if self.uses_post_method:
httpretty.register_uri(httpretty.POST, self.base_uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs)
self.assertDictEqual(httpretty.last_request().parsed_body or {}, kwargs)
else:
uri_template = '{uri}?'
for key in kwargs:
uri_template += '%s={%s}' % (key, key)
uri = uri_template.format(uri=self.base_uri, **kwargs)
httpretty.register_uri(httpretty.GET, uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs)
self.verify_last_querystring_equal(self.expected_query(**kwargs))
def test_all_items_url(self):
"""Endpoint can be called without parameters."""
httpretty.register_uri(httpretty.GET, self.base_uri, body='{}')
httpretty.register_uri(
httpretty.POST if self.uses_post_method else httpretty.GET,
self.base_uri, body='{}'
)
getattr(self.client_class, self.endpoint)()
@ddt.data(
......
......@@ -46,6 +46,11 @@ class ClientTests(ClientTestCase):
httpretty.register_uri(httpretty.GET, self.test_url, body=json.dumps(data))
self.assertEquals(self.client.get(self.test_endpoint), data)
def test_post(self):
data = {'foo': 'bar'}
httpretty.register_uri(httpretty.POST, self.test_url, body=json.dumps(data))
self.assertEquals(self.client.post(self.test_endpoint), data)
def test_get_invalid_response_body(self):
""" Verify that client raises a ClientError if the response body cannot be properly parsed. """
......@@ -71,7 +76,13 @@ class ClientTests(ClientTestCase):
timeout = None
headers = {'Accept': 'application/json'}
self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout)
self.assertRaises(
TimeoutError,
self.client._request,
self.client.METHOD_GET,
self.test_endpoint,
timeout=timeout
)
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout)
lc.check(('analyticsclient.client', 'ERROR', msg))
lc.clear()
......@@ -79,7 +90,13 @@ class ClientTests(ClientTestCase):
mock_get.reset_mock()
timeout = 10
self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout)
self.assertRaises(
TimeoutError,
self.client._request,
self.client.METHOD_GET,
self.test_endpoint,
timeout=timeout
)
mock_get.assert_called_once_with(url, headers=headers, timeout=timeout)
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout)
lc.check(('analyticsclient.client', 'ERROR', msg))
......@@ -100,3 +117,11 @@ class ClientTests(ClientTestCase):
response = self.client.get(self.test_endpoint, data_format=data_format.JSON)
self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json')
self.assertDictEqual(response, {})
def test_unsupported_method(self):
self.assertRaises(
ValueError,
self.client._request,
'PATCH',
self.test_endpoint
)
......@@ -9,6 +9,7 @@ class CourseSummariesTests(APIListTestCase, ClientTestCase):
endpoint = 'course_summaries'
id_field = 'course_ids'
uses_post_method = True
@ddt.data(
['123'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment