Commit 22471dc5 by Kyle McCormick

Change course_summaries() to use POST instead of GET

POST method allows large number of course ID arguments to be passed as
data, while GET method is restricted by URL length.

EDUCATOR-464
parent e64dc8ca
...@@ -6,3 +6,4 @@ Dylan Rhodes <dylanr@stanford.edu> ...@@ -6,3 +6,4 @@ Dylan Rhodes <dylanr@stanford.edu>
Dmitry Viskov <dmitry.viskov@webenterprise.ru> Dmitry Viskov <dmitry.viskov@webenterprise.ru>
Tyler Hallada <thallada@edx.org> Tyler Hallada <thallada@edx.org>
Braden MacDonald <braden@opencraft.com> Braden MacDonald <braden@opencraft.com>
Kyle McCormick <kylemccor@gmail.com>
...@@ -28,6 +28,9 @@ class Client(object): ...@@ -28,6 +28,9 @@ class Client(object):
DATE_FORMAT = '%Y-%m-%d' DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S' DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S'
METHOD_GET = 'GET'
METHOD_POST = 'POST'
def __init__(self, base_url, auth_token=None, timeout=0.25): def __init__(self, base_url, auth_token=None, timeout=0.25):
""" """
Initialize the client. Initialize the client.
...@@ -63,17 +66,37 @@ class Client(object): ...@@ -63,17 +66,37 @@ class Client(object):
Raises: ClientError if the resource cannot be retrieved for any reason. Raises: ClientError if the resource cannot be retrieved for any reason.
""" """
response = self._request(resource, timeout=timeout, data_format=data_format) return self._get_or_post(
self.METHOD_GET,
resource,
timeout=timeout,
data_format=data_format
)
def post(self, resource, post_data=None, timeout=None, data_format=DF.JSON):
"""
Retrieve the data for POST request.
if data_format == DF.CSV: Arguments:
return response.text
try: resource (str): Path in the form of slash separated strings.
return response.json() post_data (dict): Dictionary containing POST data.
except ValueError: timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and
message = 'Unable to decode JSON response' raising an error.
log.exception(message) data_format (str): Format in which data should be returned
raise ClientError(message)
Returns: API response data in specified data_format
Raises: ClientError if the resource cannot be retrieved for any reason.
"""
return self._get_or_post(
self.METHOD_POST,
resource,
post_data=post_data,
timeout=timeout,
data_format=data_format
)
def has_resource(self, resource, timeout=None): def has_resource(self, resource, timeout=None):
""" """
...@@ -91,13 +114,32 @@ class Client(object): ...@@ -91,13 +114,32 @@ class Client(object):
""" """
try: try:
self._request(resource, timeout=timeout) self._request(self.METHOD_GET, resource, timeout=timeout)
return True return True
except ClientError: except ClientError:
return False return False
def _get_or_post(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON):
response = self._request(
method,
resource,
post_data=post_data,
timeout=timeout,
data_format=data_format
)
if data_format == DF.CSV:
return response.text
try:
return response.json()
except ValueError:
message = 'Unable to decode JSON response'
log.exception(message)
raise ClientError(message)
# pylint: disable=no-member # pylint: disable=no-member
def _request(self, resource, timeout=None, data_format=DF.JSON): def _request(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON):
if timeout is None: if timeout is None:
timeout = self.timeout timeout = self.timeout
...@@ -114,7 +156,17 @@ class Client(object): ...@@ -114,7 +156,17 @@ class Client(object):
try: try:
uri = '{0}/{1}'.format(self.base_url, resource) uri = '{0}/{1}'.format(self.base_url, resource)
response = requests.get(uri, headers=headers, timeout=timeout)
if method == self.METHOD_GET:
response = requests.get(uri, headers=headers, timeout=timeout)
elif method == self.METHOD_POST:
response = requests.post(uri, data=(post_data or {}), headers=headers, timeout=timeout)
else:
raise ValueError(
'Invalid \'method\' argument: expected {0} or {1}, got {2}'.format(
self.METHOD_GET, self.METHOD_POST, method
)
)
status = response.status_code status = response.status_code
if status != requests.codes.ok: if status != requests.codes.ok:
......
import urllib
import analyticsclient.constants.data_format as DF import analyticsclient.constants.data_format as DF
...@@ -27,15 +25,12 @@ class CourseSummaries(object): ...@@ -27,15 +25,12 @@ class CourseSummaries(object):
exclude: Array of fields to exclude from response. Default is to not exclude any fields. exclude: Array of fields to exclude from response. Default is to not exclude any fields.
programs: If included in the query parameters, will include the programs array in the response. programs: If included in the query parameters, will include the programs array in the response.
""" """
query_params = {} post_data = {}
for query_arg, data in zip(['course_ids', 'fields', 'exclude', 'programs'], for param_name, data in zip(['course_ids', 'fields', 'exclude', 'programs'],
[course_ids, fields, exclude, programs]): [course_ids, fields, exclude, programs]):
if data: if data:
query_params[query_arg] = ','.join(data) post_data[param_name] = data
path = 'course_summaries/' path = 'course_summaries/'
querystring = urllib.urlencode(query_params)
if querystring:
path += '?{0}'.format(querystring)
return self.client.get(path, data_format=data_format) return self.client.post(path, post_data=post_data, data_format=data_format)
...@@ -34,6 +34,7 @@ class APIListTestCase(object): ...@@ -34,6 +34,7 @@ class APIListTestCase(object):
# Override in the subclass: # Override in the subclass:
endpoint = 'list' endpoint = 'list'
id_field = 'id' id_field = 'id'
uses_post_method = False
def setUp(self): def setUp(self):
"""Set up the test case.""" """Set up the test case."""
...@@ -58,17 +59,25 @@ class APIListTestCase(object): ...@@ -58,17 +59,25 @@ class APIListTestCase(object):
def kwarg_test(self, **kwargs): def kwarg_test(self, **kwargs):
"""Construct URL with given query parameters and check if it is what we expect.""" """Construct URL with given query parameters and check if it is what we expect."""
httpretty.reset() httpretty.reset()
uri_template = '{uri}?' if self.uses_post_method:
for key in kwargs: httpretty.register_uri(httpretty.POST, self.base_uri, body='{}')
uri_template += '%s={%s}' % (key, key) getattr(self.client_class, self.endpoint)(**kwargs)
uri = uri_template.format(uri=self.base_uri, **kwargs) self.assertDictEqual(httpretty.last_request().parsed_body or {}, kwargs)
httpretty.register_uri(httpretty.GET, uri, body='{}') else:
getattr(self.client_class, self.endpoint)(**kwargs) uri_template = '{uri}?'
self.verify_last_querystring_equal(self.expected_query(**kwargs)) for key in kwargs:
uri_template += '%s={%s}' % (key, key)
uri = uri_template.format(uri=self.base_uri, **kwargs)
httpretty.register_uri(httpretty.GET, uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs)
self.verify_last_querystring_equal(self.expected_query(**kwargs))
def test_all_items_url(self): def test_all_items_url(self):
"""Endpoint can be called without parameters.""" """Endpoint can be called without parameters."""
httpretty.register_uri(httpretty.GET, self.base_uri, body='{}') httpretty.register_uri(
httpretty.POST if self.uses_post_method else httpretty.GET,
self.base_uri, body='{}'
)
getattr(self.client_class, self.endpoint)() getattr(self.client_class, self.endpoint)()
@ddt.data( @ddt.data(
......
...@@ -46,6 +46,11 @@ class ClientTests(ClientTestCase): ...@@ -46,6 +46,11 @@ class ClientTests(ClientTestCase):
httpretty.register_uri(httpretty.GET, self.test_url, body=json.dumps(data)) httpretty.register_uri(httpretty.GET, self.test_url, body=json.dumps(data))
self.assertEquals(self.client.get(self.test_endpoint), data) self.assertEquals(self.client.get(self.test_endpoint), data)
def test_post(self):
data = {'foo': 'bar'}
httpretty.register_uri(httpretty.POST, self.test_url, body=json.dumps(data))
self.assertEquals(self.client.post(self.test_endpoint), data)
def test_get_invalid_response_body(self): def test_get_invalid_response_body(self):
""" Verify that client raises a ClientError if the response body cannot be properly parsed. """ """ Verify that client raises a ClientError if the response body cannot be properly parsed. """
...@@ -71,7 +76,13 @@ class ClientTests(ClientTestCase): ...@@ -71,7 +76,13 @@ class ClientTests(ClientTestCase):
timeout = None timeout = None
headers = {'Accept': 'application/json'} headers = {'Accept': 'application/json'}
self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout) self.assertRaises(
TimeoutError,
self.client._request,
self.client.METHOD_GET,
self.test_endpoint,
timeout=timeout
)
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout)
lc.check(('analyticsclient.client', 'ERROR', msg)) lc.check(('analyticsclient.client', 'ERROR', msg))
lc.clear() lc.clear()
...@@ -79,7 +90,13 @@ class ClientTests(ClientTestCase): ...@@ -79,7 +90,13 @@ class ClientTests(ClientTestCase):
mock_get.reset_mock() mock_get.reset_mock()
timeout = 10 timeout = 10
self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout) self.assertRaises(
TimeoutError,
self.client._request,
self.client.METHOD_GET,
self.test_endpoint,
timeout=timeout
)
mock_get.assert_called_once_with(url, headers=headers, timeout=timeout) mock_get.assert_called_once_with(url, headers=headers, timeout=timeout)
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout)
lc.check(('analyticsclient.client', 'ERROR', msg)) lc.check(('analyticsclient.client', 'ERROR', msg))
...@@ -100,3 +117,11 @@ class ClientTests(ClientTestCase): ...@@ -100,3 +117,11 @@ class ClientTests(ClientTestCase):
response = self.client.get(self.test_endpoint, data_format=data_format.JSON) response = self.client.get(self.test_endpoint, data_format=data_format.JSON)
self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json') self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json')
self.assertDictEqual(response, {}) self.assertDictEqual(response, {})
def test_unsupported_method(self):
self.assertRaises(
ValueError,
self.client._request,
'PATCH',
self.test_endpoint
)
...@@ -9,6 +9,7 @@ class CourseSummariesTests(APIListTestCase, ClientTestCase): ...@@ -9,6 +9,7 @@ class CourseSummariesTests(APIListTestCase, ClientTestCase):
endpoint = 'course_summaries' endpoint = 'course_summaries'
id_field = 'course_ids' id_field = 'course_ids'
uses_post_method = True
@ddt.data( @ddt.data(
['123'], ['123'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment