Commit 49efdabc by Kyle McCormick

Enhance CourseSummaries; add CourseTotals

Updates client to support pagination, filtering, searching, and sorting
now available in the edX Analytics Data API.

Bumps API version used in tests from v0 to v1.
parent fa05cfc1
from analyticsclient.constants import http_methods, data_formats
class BaseEndpoint(object):
"""Base class for endpoints that use a client object."""
def __init__(self, client):
"""
Initialize the API client.
Arguments:
client (analyticsclient.client.Client): The client to use to access remote resources.
"""
self.client = client
class PostableCourseIDsEndpoint(BaseEndpoint):
"""Base class for endpoints that pass in course IDs with either GET or POST."""
path = None # Override in subclass
max_num_ids_for_get = 10 # Optionally override in subclass
def do_request(self, course_ids, data, data_format=data_formats.JSON):
"""
Given course IDs, do the appropriate request method (GET or POST).
Arguments:
course_ids (list[str]): A list of course IDs to pass in
data (dict): Arguments for endpoint, sans course IDs
data_format (data_format)
Returns: dict
"""
data_with_ids = data.copy()
if course_ids:
data_with_ids.update(course_ids=course_ids)
method = (
http_methods.POST
if len(course_ids or []) > self.max_num_ids_for_get
else http_methods.GET
)
return self.client.request(method, self.path, data=data_with_ids, data_format=data_format)
...@@ -2,9 +2,10 @@ import logging ...@@ -2,9 +2,10 @@ import logging
import requests import requests
import requests.exceptions import requests.exceptions
from analyticsclient.constants import data_format as DF from analyticsclient.constants import http_methods, data_formats
from analyticsclient.course import Course from analyticsclient.course import Course
from analyticsclient.course_totals import CourseTotals
from analyticsclient.course_summaries import CourseSummaries from analyticsclient.course_summaries import CourseSummaries
from analyticsclient.exceptions import ClientError, InvalidRequestError, NotFoundError, TimeoutError from analyticsclient.exceptions import ClientError, InvalidRequestError, NotFoundError, TimeoutError
from analyticsclient.module import Module from analyticsclient.module import Module
...@@ -28,15 +29,12 @@ class Client(object): ...@@ -28,15 +29,12 @@ class Client(object):
DATE_FORMAT = '%Y-%m-%d' DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S' DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S'
METHOD_GET = 'GET'
METHOD_POST = 'POST'
def __init__(self, base_url, auth_token=None, timeout=0.25): def __init__(self, base_url, auth_token=None, timeout=0.25):
""" """
Initialize the client. Initialize the client.
Arguments: Arguments:
base_url (str): URL of the API server (e.g. http://analytics.edx.org/api/v0) base_url (str): URL of the API server (e.g. http://analytics.edx.org/api/v1)
auth_token (str): Authentication token auth_token (str): Authentication token
timeout (number): Maximum number of seconds during which all requests musts complete timeout (number): Maximum number of seconds during which all requests musts complete
""" """
...@@ -46,41 +44,32 @@ class Client(object): ...@@ -46,41 +44,32 @@ class Client(object):
self.status = Status(self) self.status = Status(self)
self.course_summaries = lambda: CourseSummaries(self) self.course_summaries = lambda: CourseSummaries(self)
self.course_totals = lambda: CourseTotals(self)
self.programs = lambda: Programs(self) self.programs = lambda: Programs(self)
self.courses = lambda course_id: Course(self, course_id) self.courses = lambda course_id: Course(self, course_id)
self.modules = lambda course_id, module_id: Module(self, course_id, module_id) self.modules = lambda course_id, module_id: Module(self, course_id, module_id)
def get(self, resource, timeout=None, data_format=DF.JSON): def get(self, *args, **kwargs):
""" """
Retrieve the data for a resource. Retrieve the data for a resource.
Arguments: Equivalent to `request(http_methods.GET, ...)`
resource (str): Path in the form of slash separated strings.
timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and
raising an error.
data_format (str): Format in which data should be returned
Returns: API response data in specified data_format Returns: API response data in specified data_format
Raises: ClientError if the resource cannot be retrieved for any reason. Raises: ClientError if the resource cannot be retrieved for any reason.
""" """
return self._get_or_post( return self.request(http_methods.GET, *args, **kwargs)
self.METHOD_GET,
resource,
timeout=timeout,
data_format=data_format
)
def post(self, resource, post_data=None, timeout=None, data_format=DF.JSON): def request(self, method, resource, data=None, timeout=None, data_format=data_formats.JSON):
""" """
Retrieve the data for POST request. Retrieve the from an HTTP request.
Arguments: Arguments:
method (http_method): HTTP method. Only GET and POST are supported currenly.
resource (str): Path in the form of slash separated strings. resource (str): Path in the form of slash separated strings.
post_data (dict): Dictionary containing POST data. data (dict): Dictionary containing POST data.
timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and
raising an error. raising an error.
data_format (str): Format in which data should be returned data_format (str): Format in which data should be returned
...@@ -90,14 +79,24 @@ class Client(object): ...@@ -90,14 +79,24 @@ class Client(object):
Raises: ClientError if the resource cannot be retrieved for any reason. Raises: ClientError if the resource cannot be retrieved for any reason.
""" """
return self._get_or_post( response = self._request(
self.METHOD_POST, method=method,
resource, resource=resource,
post_data=post_data, data=data,
timeout=timeout, timeout=timeout,
data_format=data_format data_format=data_format
) )
if data_format == data_formats.CSV:
return response.text
try:
return response.json()
except ValueError:
message = 'Unable to decode JSON response'
log.exception(message)
raise ClientError(message)
def has_resource(self, resource, timeout=None): def has_resource(self, resource, timeout=None):
""" """
Check if the server responds with a 200 OK status code when the resource is requested. Check if the server responds with a 200 OK status code when the resource is requested.
...@@ -114,37 +113,18 @@ class Client(object): ...@@ -114,37 +113,18 @@ class Client(object):
""" """
try: try:
self._request(self.METHOD_GET, resource, timeout=timeout) self._request(http_methods.GET, resource, timeout=timeout)
return True return True
except ClientError: except ClientError:
return False return False
def _get_or_post(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON):
response = self._request(
method,
resource,
post_data=post_data,
timeout=timeout,
data_format=data_format
)
if data_format == DF.CSV:
return response.text
try:
return response.json()
except ValueError:
message = 'Unable to decode JSON response'
log.exception(message)
raise ClientError(message)
# pylint: disable=no-member # pylint: disable=no-member
def _request(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON): def _request(self, method, resource, data=None, timeout=None, data_format=data_formats.JSON):
if timeout is None: if timeout is None:
timeout = self.timeout timeout = self.timeout
accept_format = 'application/json' accept_format = 'application/json'
if data_format == DF.CSV: if data_format == data_formats.CSV:
accept_format = 'text/csv' accept_format = 'text/csv'
headers = { headers = {
...@@ -157,14 +137,17 @@ class Client(object): ...@@ -157,14 +137,17 @@ class Client(object):
try: try:
uri = '{0}/{1}'.format(self.base_url, resource) uri = '{0}/{1}'.format(self.base_url, resource)
if method == self.METHOD_GET: if method == http_methods.GET:
response = requests.get(uri, headers=headers, timeout=timeout) params = self._data_to_get_params(data or {})
elif method == self.METHOD_POST: response = requests.get(uri, params=params, headers=headers, timeout=timeout)
response = requests.post(uri, data=(post_data or {}), headers=headers, timeout=timeout) elif method == http_methods.POST:
response = requests.post(uri, data=(data or {}), headers=headers, timeout=timeout)
else: else:
raise ValueError( raise ValueError(
'Invalid \'method\' argument: expected {0} or {1}, got {2}'.format( 'Invalid \'method\' argument: expected {0} or {1}, got {2}'.format(
self.METHOD_GET, self.METHOD_POST, method http_methods.GET,
http_methods.POST,
method,
) )
) )
...@@ -194,3 +177,14 @@ class Client(object): ...@@ -194,3 +177,14 @@ class Client(object):
message = 'Unable to retrieve resource' message = 'Unable to retrieve resource'
log.exception(message) log.exception(message)
raise ClientError('{0} "{1}"'.format(message, resource)) raise ClientError('{0} "{1}"'.format(message, resource))
@staticmethod
def _data_to_get_params(data):
return {
key: (
','.join(value)
if isinstance(value, list)
else str(value)
)
for key, value in data.iteritems()
}
"""Course activity types."""
ANY = 'any'
ATTEMPTED_PROBLEM = 'attempted_problem'
PLAYED_VIDEO = 'played_video'
POSTED_FORUM = 'posted_forum'
"""Course activity types."""
ANY = u'any'
ATTEMPTED_PROBLEM = u'attempted_problem'
PLAYED_VIDEO = u'played_video'
POSTED_FORUM = u'posted_forum'
"""Course demographics."""
BIRTH_YEAR = 'birth_year'
EDUCATION = 'education'
GENDER = 'gender'
LOCATION = 'location'
"""Course demographics."""
BIRTH_YEAR = u'birth_year'
EDUCATION = u'education'
GENDER = u'gender'
LOCATION = u'location'
NONE = 'none'
OTHER = 'other'
PRIMARY = 'primary'
JUNIOR_SECONDARY = 'junior_secondary'
SECONDARY = 'secondary'
ASSOCIATES = 'associates'
BACHELORS = 'bachelors'
MASTERS = 'masters'
DOCTORATE = 'doctorate'
NONE = u'none'
OTHER = u'other'
PRIMARY = u'primary'
JUNIOR_SECONDARY = u'junior_secondary'
SECONDARY = u'secondary'
ASSOCIATES = u'associates'
BACHELORS = u'bachelors'
MASTERS = u'masters'
DOCTORATE = u'doctorate'
FEMALE = 'female'
MALE = 'male'
OTHER = 'other'
UNKNOWN = 'unknown'
FEMALE = u'female'
MALE = u'male'
OTHER = u'other'
UNKNOWN = u'unknown'
GET = u'GET'
HEAD = u'HEAD'
POST = u'POST'
PUT = u'PUT'
DELETE = u'DELETE'
CONNECT = u'CONNECT'
OPTIONS = u'OPTIONS'
TRACE = u'TRACE'
PATCH = u'PATCH'
import urllib import urllib
import warnings import warnings
import analyticsclient.constants.activity_type as AT from analyticsclient.base import PostableCourseIDsEndpoint
import analyticsclient.constants.data_format as DF from analyticsclient.constants import activity_types, data_formats
from analyticsclient.exceptions import InvalidRequestError from analyticsclient.exceptions import InvalidRequestError
class Course(object): class Course(PostableCourseIDsEndpoint):
"""Course-related analytics.""" """Course-related analytics."""
def __init__(self, client, course_id): def __init__(self, client, course_id):
...@@ -18,10 +18,10 @@ class Course(object): ...@@ -18,10 +18,10 @@ class Course(object):
course_id (str): String identifying the course (e.g. edX/DemoX/Demo_Course) course_id (str): String identifying the course (e.g. edX/DemoX/Demo_Course)
""" """
self.client = client super(Course, self).__init__(client)
self.course_id = unicode(course_id) self.course_id = unicode(course_id)
def enrollment(self, demographic=None, start_date=None, end_date=None, data_format=DF.JSON): def enrollment(self, demographic=None, start_date=None, end_date=None, data_format=data_formats.JSON):
""" """
Get course enrollment data. Get course enrollment data.
...@@ -55,7 +55,7 @@ class Course(object): ...@@ -55,7 +55,7 @@ class Course(object):
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def activity(self, activity_type=AT.ANY, start_date=None, end_date=None, data_format=DF.JSON): def activity(self, activity_type=activity_types.ANY, start_date=None, end_date=None, data_format=data_formats.JSON):
""" """
Get the course student activity. Get the course student activity.
...@@ -82,7 +82,7 @@ class Course(object): ...@@ -82,7 +82,7 @@ class Course(object):
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def recent_activity(self, activity_type=AT.ANY, data_format=DF.JSON): def recent_activity(self, activity_type=activity_types.ANY, data_format=data_formats.JSON):
""" """
Get the recent course activity. Get the recent course activity.
...@@ -95,7 +95,7 @@ class Course(object): ...@@ -95,7 +95,7 @@ class Course(object):
path = 'courses/{0}/recent_activity/?activity_type={1}'.format(self.course_id, activity_type) path = 'courses/{0}/recent_activity/?activity_type={1}'.format(self.course_id, activity_type)
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def problems(self, data_format=DF.JSON): def problems(self, data_format=data_formats.JSON):
""" """
Get the problems for the course. Get the problems for the course.
...@@ -105,7 +105,7 @@ class Course(object): ...@@ -105,7 +105,7 @@ class Course(object):
path = 'courses/{0}/problems/'.format(self.course_id) path = 'courses/{0}/problems/'.format(self.course_id)
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def problems_and_tags(self, data_format=DF.JSON): def problems_and_tags(self, data_format=data_formats.JSON):
""" """
Get the problems for the course with assigned tags. Get the problems for the course with assigned tags.
...@@ -115,7 +115,7 @@ class Course(object): ...@@ -115,7 +115,7 @@ class Course(object):
path = 'courses/{0}/problems_and_tags/'.format(self.course_id) path = 'courses/{0}/problems_and_tags/'.format(self.course_id)
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def reports(self, report_name, data_format=DF.JSON): def reports(self, report_name, data_format=data_formats.JSON):
""" """
Get CSV download information for a particular report in the course. Get CSV download information for a particular report in the course.
...@@ -125,7 +125,7 @@ class Course(object): ...@@ -125,7 +125,7 @@ class Course(object):
path = 'courses/{0}/reports/{1}/'.format(self.course_id, report_name) path = 'courses/{0}/reports/{1}/'.format(self.course_id, report_name)
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def videos(self, data_format=DF.JSON): def videos(self, data_format=data_formats.JSON):
""" """
Get the videos for the course. Get the videos for the course.
......
import analyticsclient.constants.data_format as DF from analyticsclient.base import PostableCourseIDsEndpoint
from analyticsclient.constants import data_formats
class CourseSummaries(object): class CourseSummaries(PostableCourseIDsEndpoint):
"""Course summaries.""" """Course summaries."""
def __init__(self, client): path = 'course_summaries/'
"""
Initialize the CourseSummaries client.
Arguments:
client (analyticsclient.client.Client): The client to use to access remote resources.
"""
self.client = client
def course_summaries(self, course_ids=None, fields=None, exclude=None, programs=None, data_format=DF.JSON): def course_summaries(
self,
course_ids=None,
availability=None,
pacing_type=None,
program_ids=None,
text_search=None,
order_by=None,
sort_order=None,
page=None,
page_size=None,
request_all=False,
fields=None,
exclude=None,
data_format=data_formats.JSON,
):
""" """
Get list of summaries. Get list of summaries.
Arguments: For more detailed parameter and return type descriptions, see the
course_ids: Array of course IDs as strings to return. Default is to return all. edX Analytics Data API documentation.
fields: Array of fields to return. Default is to return all.
exclude: Array of fields to exclude from response. Default is to not exclude any fields.
programs: If included in the query parameters, will include the programs array in the response.
"""
post_data = {}
for param_name, data in zip(['course_ids', 'fields', 'exclude', 'programs'],
[course_ids, fields, exclude, programs]):
if data:
post_data[param_name] = data
path = 'course_summaries/' Arguments:
course_ids (list[str]): Course IDs to filter by.
availability (list[str]) Availabilities to filter by.
pacing_type (list[str]): Pacing types to filter by.
program_ids (list[str]): Course IDs of programs to filter by.
text_search (str): Sub-string to search for in course titles and IDs.
order_by (str): Summary field to sort by.
sort_order (str): Order of the sort.
page (int): Page number.
page_size (int): Size of page.
request_all (bool): Whether all summaries should be returned, or just a
single page. Overrides `page` and `page_size`.
fields (list[str]) Fields of course summaries to return in response.
exclude (list[str]) Fields of course summaries to NOT return in response.
data_format (str): Data format for response. Must be data_format.JSON or
data_format.CSV.
return self.client.post(path, post_data=post_data, data_format=data_format) Returns: dict
"""
raw_data = {
'availability': availability,
'pacing_type': pacing_type,
'program_ids': program_ids,
'text_search': text_search,
'order_by': order_by,
'sort_order': sort_order,
'page': page,
'page_size': page_size,
'fields': fields,
'exclude': exclude,
'all': request_all,
}
data = {
key: value
for key, value in raw_data.iteritems()
if value
}
return self.do_request(course_ids=course_ids, data=data, data_format=data_format)
from analyticsclient.base import PostableCourseIDsEndpoint
from analyticsclient.constants import data_formats
class CourseTotals(PostableCourseIDsEndpoint):
"""Course aggregate data."""
path = 'course_totals/'
def course_totals(self, course_ids=None, data_format=data_formats.JSON):
"""
Get aggregate data about courses.
For more detailed parameter and return type descriptions, see the
edX Analytics Data API documentation.
Arguments:
course_ids (list[str]): Course IDs to filter by.
data_format (str): Data format for response.
Must be data_format.JSON or data_format.CSV.
"""
return self.do_request(course_ids=course_ids, data={}, data_format=data_format)
import analyticsclient.constants.data_format as DF from analyticsclient.base import BaseEndpoint
from analyticsclient.constants import data_formats
class Module(object): class Module(BaseEndpoint):
"""Module related analytics data.""" """Module related analytics data."""
def __init__(self, client, course_id, module_id): def __init__(self, client, course_id, module_id):
...@@ -13,11 +14,11 @@ class Module(object): ...@@ -13,11 +14,11 @@ class Module(object):
course_id (str): String identifying the course course_id (str): String identifying the course
module_id (str): String identifying the module module_id (str): String identifying the module
""" """
self.client = client super(Module, self).__init__(client)
self.course_id = unicode(course_id) self.course_id = unicode(course_id)
self.module_id = unicode(module_id) self.module_id = unicode(module_id)
def answer_distribution(self, data_format=DF.JSON): def answer_distribution(self, data_format=data_formats.JSON):
""" """
Get answer distribution data for a module. Get answer distribution data for a module.
...@@ -28,7 +29,7 @@ class Module(object): ...@@ -28,7 +29,7 @@ class Module(object):
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def grade_distribution(self, data_format=DF.JSON): def grade_distribution(self, data_format=data_formats.JSON):
""" """
Get grade distribution data for a module. Get grade distribution data for a module.
...@@ -39,7 +40,7 @@ class Module(object): ...@@ -39,7 +40,7 @@ class Module(object):
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def sequential_open_distribution(self, data_format=DF.JSON): def sequential_open_distribution(self, data_format=data_formats.JSON):
""" """
Get open distribution data for a module. Get open distribution data for a module.
...@@ -50,7 +51,7 @@ class Module(object): ...@@ -50,7 +51,7 @@ class Module(object):
return self.client.get(path, data_format=data_format) return self.client.get(path, data_format=data_format)
def video_timeline(self, data_format=DF.JSON): def video_timeline(self, data_format=data_formats.JSON):
""" """
Get video segments/timeline for a module. Get video segments/timeline for a module.
......
import urllib import urllib
import analyticsclient.constants.data_format as DF from analyticsclient.base import BaseEndpoint
from analyticsclient.constants import data_formats
class Programs(object): class Programs(BaseEndpoint):
"""Programs client.""" """Programs client."""
def __init__(self, client): def programs(self, program_ids=None, fields=None, exclude=None, data_format=data_formats.JSON, **kwargs):
"""
Initialize the Programs client.
Arguments:
client (analyticsclient.client.Client): The client to use to access remote resources.
"""
self.client = client
def programs(self, program_ids=None, fields=None, exclude=None, data_format=DF.JSON):
""" """
Get list of programs metadata. Get list of programs metadata.
...@@ -28,7 +18,7 @@ class Programs(object): ...@@ -28,7 +18,7 @@ class Programs(object):
""" """
query_params = {} query_params = {}
for query_arg, data in zip(['program_ids', 'fields', 'exclude'], for query_arg, data in zip(['program_ids', 'fields', 'exclude'],
[program_ids, fields, exclude]): [program_ids, fields, exclude]) + kwargs.items():
if data: if data:
query_params[query_arg] = ','.join(data) query_params[query_arg] = ','.join(data)
......
from analyticsclient.base import BaseEndpoint
from analyticsclient.exceptions import ClientError from analyticsclient.exceptions import ClientError
class Status(object): class Status(BaseEndpoint):
"""API server status.""" """API server status."""
def __init__(self, client):
"""
Initialize the Status.
Arguments:
client (analyticsclient.client.Client): The client to use to access remote resources.
"""
self.client = client
@property @property
def alive(self): def alive(self):
""" """
......
...@@ -11,7 +11,7 @@ class ClientTestCase(TestCase): ...@@ -11,7 +11,7 @@ class ClientTestCase(TestCase):
def setUp(self): def setUp(self):
"""Configure Client.""" """Configure Client."""
self.api_url = 'http://localhost:9999/api/v0' self.api_url = 'http://localhost:9999/api/v1'
self.client = Client(self.api_url) self.client = Client(self.api_url)
def get_api_url(self, path): def get_api_url(self, path):
...@@ -28,17 +28,17 @@ class ClientTestCase(TestCase): ...@@ -28,17 +28,17 @@ class ClientTestCase(TestCase):
@ddt.ddt @ddt.ddt
class APIListTestCase(object): class APIWithIDsTestCase(object):
"""Base class for API list view tests.""" """Base class for tests for API endpoints that take lists of IDs."""
# Override in the subclass: # Override in the subclass:
endpoint = 'list' endpoint = 'endpoint'
id_field = 'id' id_field = 'id'
uses_post_method = False other_params = frozenset()
def setUp(self): def setUp(self):
"""Set up the test case.""" """Set up the test case."""
super(APIListTestCase, self).setUp() super(APIWithIDsTestCase, self).setUp()
self.base_uri = self.get_api_url('{}/'.format(self.endpoint)) self.base_uri = self.get_api_url('{}/'.format(self.endpoint))
self.client_class = getattr(self.client, self.endpoint)() self.client_class = getattr(self.client, self.endpoint)()
httpretty.enable() httpretty.enable()
...@@ -49,44 +49,76 @@ class APIListTestCase(object): ...@@ -49,44 +49,76 @@ class APIListTestCase(object):
def expected_query(self, **kwargs): def expected_query(self, **kwargs):
"""Pack the query arguments into expected format for http pretty.""" """Pack the query arguments into expected format for http pretty."""
query = {} return {
for field, data in kwargs.items(): field: (
if data is not None: [','.join(data)] if isinstance(data, list) else [str(data)]
query[field] = [','.join(data)] )
return query for field, data in kwargs.iteritems()
if data
}
@httpretty.activate @httpretty.activate
def kwarg_test(self, **kwargs): def verify_query_params(self, **kwargs):
"""Construct URL with given query parameters and check if it is what we expect.""" """Construct URL with given query parameters and check if it is what we expect."""
httpretty.reset() httpretty.reset()
if self.uses_post_method:
httpretty.register_uri(httpretty.POST, self.base_uri, body='{}') uri_template = '{uri}?'
getattr(self.client_class, self.endpoint)(**kwargs) for key in kwargs:
self.assertDictEqual(httpretty.last_request().parsed_body or {}, kwargs) uri_template += '%s={%s}' % (key, key)
else: uri = uri_template.format(uri=self.base_uri, **kwargs)
uri_template = '{uri}?'
for key in kwargs: httpretty.register_uri(httpretty.GET, uri, body='{}')
uri_template += '%s={%s}' % (key, key) getattr(self.client_class, self.endpoint)(**kwargs)
uri = uri_template.format(uri=self.base_uri, **kwargs) self.verify_last_querystring_equal(self.expected_query(**kwargs))
httpretty.register_uri(httpretty.GET, uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs) def fill_in_empty_params_with_dummies(self, **kwargs):
self.verify_last_querystring_equal(self.expected_query(**kwargs)) """Fill in non-provided parameters with dummy values so they are tested."""
params = {param: '.' for param in self.other_params}
def test_all_items_url(self): params.update(kwargs)
"""Endpoint can be called without parameters.""" return params
httpretty.register_uri(
httpretty.POST if self.uses_post_method else httpretty.GET,
self.base_uri, body='{}'
)
getattr(self.client_class, self.endpoint)()
@ddt.data( @ddt.data(
[],
['edx/demo/course'], ['edx/demo/course'],
['edx/demo/course', 'another/demo/course'] ['edx/demo/course', 'another/demo/course'],
) )
def test_courses_ids(self, ids): def test_url_with_params(self, ids):
"""Endpoint can be called with IDs.""" """Endpoint can be called with parameters, including IDs."""
self.kwarg_test(**{self.id_field: ids}) params = self.fill_in_empty_params_with_dummies(**{self.id_field: ids})
self.verify_query_params(**params)
def test_url_without_params(self):
"""Endpoint can be called without parameters."""
httpretty.register_uri(httpretty.GET, self.base_uri, body='{}')
getattr(self.client_class, self.endpoint)()
class APIWithPostableIDsTestCase(APIWithIDsTestCase):
"""Base class for tests for API endpoints that can POST a list of course IDs."""
@httpretty.activate
def verify_post_data(self, **kwargs):
"""Construct POST request with parameters and check if it is what we expect."""
httpretty.reset()
httpretty.register_uri(httpretty.POST, self.base_uri, body='{}')
getattr(self.client_class, self.endpoint)(**kwargs)
expected_body = kwargs.copy()
for key, val in expected_body.iteritems():
if not isinstance(val, list):
expected_body[key] = [val]
actual_body = httpretty.last_request().parsed_body
self.assertDictEqual(actual_body or {}, expected_body)
def test_request_with_many_ids(self):
"""Endpoint can be called with a large number of ID parameters."""
params = self.fill_in_empty_params_with_dummies(**{self.id_field: ['id'] * 10000})
self.verify_post_data(**params)
@ddt.ddt
class APIListTestCase(APIWithIDsTestCase):
"""Base class for API list view tests."""
@ddt.data( @ddt.data(
['course_id'], ['course_id'],
...@@ -94,7 +126,7 @@ class APIListTestCase(object): ...@@ -94,7 +126,7 @@ class APIListTestCase(object):
) )
def test_fields(self, fields): def test_fields(self, fields):
"""Endpoint can be called with fields.""" """Endpoint can be called with fields."""
self.kwarg_test(fields=fields) self.verify_query_params(fields=fields)
@ddt.data( @ddt.data(
['course_id'], ['course_id'],
...@@ -102,7 +134,7 @@ class APIListTestCase(object): ...@@ -102,7 +134,7 @@ class APIListTestCase(object):
) )
def test_exclude(self, exclude): def test_exclude(self, exclude):
"""Endpoint can be called with exclude.""" """Endpoint can be called with exclude."""
self.kwarg_test(exclude=exclude) self.verify_query_params(exclude=exclude)
@ddt.data( @ddt.data(
(['edx/demo/course'], ['course_id'], ['enrollment_modes']), (['edx/demo/course'], ['course_id'], ['enrollment_modes']),
...@@ -110,6 +142,9 @@ class APIListTestCase(object): ...@@ -110,6 +142,9 @@ class APIListTestCase(object):
['created', 'pacing_type']) ['created', 'pacing_type'])
) )
@ddt.unpack @ddt.unpack
def test_all_parameters(self, ids, fields, exclude): def test_all_list_parameters(self, ids, fields, exclude):
"""Endpoint can be called with all parameters.""" """Endpoint can be called with IDs, fields, and exlude parameters."""
self.kwarg_test(**{self.id_field: ids, 'fields': fields, 'exclude': exclude}) params = self.fill_in_empty_params_with_dummies(
**{self.id_field: ids, 'fields': fields, 'exclude': exclude}
)
self.verify_query_params(**params)
...@@ -5,7 +5,7 @@ import mock ...@@ -5,7 +5,7 @@ import mock
import requests.exceptions import requests.exceptions
from testfixtures import log_capture from testfixtures import log_capture
from analyticsclient.constants import data_format from analyticsclient.constants import data_formats, http_methods
from analyticsclient.client import Client from analyticsclient.client import Client
from analyticsclient.exceptions import ClientError, TimeoutError from analyticsclient.exceptions import ClientError, TimeoutError
from analyticsclient.tests import ClientTestCase from analyticsclient.tests import ClientTestCase
...@@ -49,7 +49,7 @@ class ClientTests(ClientTestCase): ...@@ -49,7 +49,7 @@ class ClientTests(ClientTestCase):
def test_post(self): def test_post(self):
data = {'foo': 'bar'} data = {'foo': 'bar'}
httpretty.register_uri(httpretty.POST, self.test_url, body=json.dumps(data)) httpretty.register_uri(httpretty.POST, self.test_url, body=json.dumps(data))
self.assertEquals(self.client.post(self.test_endpoint), data) self.assertEquals(self.client.request(http_methods.POST, self.test_endpoint), data)
def test_get_invalid_response_body(self): def test_get_invalid_response_body(self):
""" Verify that client raises a ClientError if the response body cannot be properly parsed. """ """ Verify that client raises a ClientError if the response body cannot be properly parsed. """
...@@ -79,25 +79,25 @@ class ClientTests(ClientTestCase): ...@@ -79,25 +79,25 @@ class ClientTests(ClientTestCase):
self.assertRaises( self.assertRaises(
TimeoutError, TimeoutError,
self.client._request, self.client._request,
self.client.METHOD_GET, http_methods.GET,
self.test_endpoint, self.test_endpoint,
timeout=timeout timeout=timeout
) )
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout)
lc.check(('analyticsclient.client', 'ERROR', msg)) lc.check(('analyticsclient.client', 'ERROR', msg))
lc.clear() lc.clear()
mock_get.assert_called_once_with(url, headers=headers, timeout=self.client.timeout) mock_get.assert_called_once_with(url, headers=headers, timeout=self.client.timeout, params={})
mock_get.reset_mock() mock_get.reset_mock()
timeout = 10 timeout = 10
self.assertRaises( self.assertRaises(
TimeoutError, TimeoutError,
self.client._request, self.client._request,
self.client.METHOD_GET, http_methods.GET,
self.test_endpoint, self.test_endpoint,
timeout=timeout timeout=timeout
) )
mock_get.assert_called_once_with(url, headers=headers, timeout=timeout) mock_get.assert_called_once_with(url, headers=headers, timeout=timeout, params={})
msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout)
lc.check(('analyticsclient.client', 'ERROR', msg)) lc.check(('analyticsclient.client', 'ERROR', msg))
...@@ -109,12 +109,12 @@ class ClientTests(ClientTestCase): ...@@ -109,12 +109,12 @@ class ClientTests(ClientTestCase):
self.assertDictEqual(response, {}) self.assertDictEqual(response, {})
httpretty.register_uri(httpretty.GET, self.test_url, body='not-json') httpretty.register_uri(httpretty.GET, self.test_url, body='not-json')
response = self.client.get(self.test_endpoint, data_format=data_format.CSV) response = self.client.get(self.test_endpoint, data_format=data_formats.CSV)
self.assertEquals(httpretty.last_request().headers['Accept'], 'text/csv') self.assertEquals(httpretty.last_request().headers['Accept'], 'text/csv')
self.assertEqual(response, 'not-json') self.assertEqual(response, 'not-json')
httpretty.register_uri(httpretty.GET, self.test_url, body='{}') httpretty.register_uri(httpretty.GET, self.test_url, body='{}')
response = self.client.get(self.test_endpoint, data_format=data_format.JSON) response = self.client.get(self.test_endpoint, data_format=data_formats.JSON)
self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json') self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json')
self.assertDictEqual(response, {}) self.assertDictEqual(response, {})
...@@ -122,6 +122,6 @@ class ClientTests(ClientTestCase): ...@@ -122,6 +122,6 @@ class ClientTests(ClientTestCase):
self.assertRaises( self.assertRaises(
ValueError, ValueError,
self.client._request, self.client._request,
'PATCH', http_methods.PATCH,
self.test_endpoint self.test_endpoint
) )
...@@ -3,9 +3,7 @@ import re ...@@ -3,9 +3,7 @@ import re
import httpretty import httpretty
from analyticsclient.constants import activity_type as at from analyticsclient.constants import activity_types, data_formats, demographics
from analyticsclient.constants import data_format
from analyticsclient.constants import demographic as demo
from analyticsclient.exceptions import NotFoundError, InvalidRequestError from analyticsclient.exceptions import NotFoundError, InvalidRequestError
from analyticsclient.tests import ClientTestCase from analyticsclient.tests import ClientTestCase
...@@ -82,10 +80,10 @@ class CoursesTests(ClientTestCase): ...@@ -82,10 +80,10 @@ class CoursesTests(ClientTestCase):
self.assertDictEqual(body, self.course.recent_activity(activity_type)) self.assertDictEqual(body, self.course.recent_activity(activity_type))
def test_recent_activity(self): def test_recent_activity(self):
self.assertRecentActivityResponseData(self.course, at.ANY) self.assertRecentActivityResponseData(self.course, activity_types.ANY)
self.assertRecentActivityResponseData(self.course, at.ATTEMPTED_PROBLEM) self.assertRecentActivityResponseData(self.course, activity_types.ATTEMPTED_PROBLEM)
self.assertRecentActivityResponseData(self.course, at.PLAYED_VIDEO) self.assertRecentActivityResponseData(self.course, activity_types.PLAYED_VIDEO)
self.assertRecentActivityResponseData(self.course, at.POSTED_FORUM) self.assertRecentActivityResponseData(self.course, activity_types.POSTED_FORUM)
def test_not_found(self): def test_not_found(self):
""" Course calls should raise a NotFoundError when provided with an invalid course. """ """ Course calls should raise a NotFoundError when provided with an invalid course. """
...@@ -96,8 +94,8 @@ class CoursesTests(ClientTestCase): ...@@ -96,8 +94,8 @@ class CoursesTests(ClientTestCase):
httpretty.register_uri(httpretty.GET, uri, status=404) httpretty.register_uri(httpretty.GET, uri, status=404)
course = self.client.courses(course_id) course = self.client.courses(course_id)
self.assertRaises(NotFoundError, course.recent_activity, at.ANY) self.assertRaises(NotFoundError, course.recent_activity, activity_types.ANY)
self.assertRaises(NotFoundError, course.enrollment, demo.EDUCATION) self.assertRaises(NotFoundError, course.enrollment, demographics.EDUCATION)
def test_invalid_parameter(self): def test_invalid_parameter(self):
""" Course calls should raise a InvalidRequestError when parameters are invalid. """ """ Course calls should raise a InvalidRequestError when parameters are invalid. """
...@@ -111,17 +109,17 @@ class CoursesTests(ClientTestCase): ...@@ -111,17 +109,17 @@ class CoursesTests(ClientTestCase):
def test_enrollment(self): def test_enrollment(self):
self.assertCorrectEnrollmentUrl(self.course, None) self.assertCorrectEnrollmentUrl(self.course, None)
self.assertCorrectEnrollmentUrl(self.course, demo.BIRTH_YEAR) self.assertCorrectEnrollmentUrl(self.course, demographics.BIRTH_YEAR)
self.assertCorrectEnrollmentUrl(self.course, demo.EDUCATION) self.assertCorrectEnrollmentUrl(self.course, demographics.EDUCATION)
self.assertCorrectEnrollmentUrl(self.course, demo.GENDER) self.assertCorrectEnrollmentUrl(self.course, demographics.GENDER)
self.assertCorrectEnrollmentUrl(self.course, demo.LOCATION) self.assertCorrectEnrollmentUrl(self.course, demographics.LOCATION)
def test_activity(self): def test_activity(self):
self.assertRaises(InvalidRequestError, self.assertCorrectActivityUrl, self.course, None) self.assertRaises(InvalidRequestError, self.assertCorrectActivityUrl, self.course, None)
self.assertCorrectActivityUrl(self.course, at.ANY) self.assertCorrectActivityUrl(self.course, activity_types.ANY)
self.assertCorrectActivityUrl(self.course, at.ATTEMPTED_PROBLEM) self.assertCorrectActivityUrl(self.course, activity_types.ATTEMPTED_PROBLEM)
self.assertCorrectActivityUrl(self.course, at.PLAYED_VIDEO) self.assertCorrectActivityUrl(self.course, activity_types.PLAYED_VIDEO)
self.assertCorrectActivityUrl(self.course, at.POSTED_FORUM) self.assertCorrectActivityUrl(self.course, activity_types.POSTED_FORUM)
def test_enrollment_data_format(self): def test_enrollment_data_format(self):
uri = self.get_api_url('courses/{0}/enrollment/'.format(self.course.course_id)) uri = self.get_api_url('courses/{0}/enrollment/'.format(self.course.course_id))
...@@ -132,7 +130,7 @@ class CoursesTests(ClientTestCase): ...@@ -132,7 +130,7 @@ class CoursesTests(ClientTestCase):
self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json') self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json')
httpretty.register_uri(httpretty.GET, uri, body='not-json') httpretty.register_uri(httpretty.GET, uri, body='not-json')
self.course.enrollment(data_format=data_format.CSV) self.course.enrollment(data_format=data_formats.CSV)
self.assertEquals(httpretty.last_request().headers['Accept'], 'text/csv') self.assertEquals(httpretty.last_request().headers['Accept'], 'text/csv')
@httpretty.activate @httpretty.activate
......
# pylint: disable=arguments-differ
import ddt import ddt
from analyticsclient.tests import ClientTestCase, APIListTestCase from analyticsclient.tests import (
APIListTestCase,
APIWithPostableIDsTestCase,
ClientTestCase
)
@ddt.ddt @ddt.ddt
class CourseSummariesTests(APIListTestCase, ClientTestCase): class CourseSummariesTests(APIListTestCase, APIWithPostableIDsTestCase, ClientTestCase):
endpoint = 'course_summaries' endpoint = 'course_summaries'
id_field = 'course_ids' id_field = 'course_ids'
uses_post_method = True
@ddt.data( _LIST_PARAMS = frozenset([
['123'], 'course_ids',
['123', '456'] 'availability',
) 'pacing_type',
def test_programs(self, programs): 'program_ids',
"""Course summaries can be called with programs.""" 'fields',
self.kwarg_test(programs=programs) 'exclude',
])
_STRING_PARAMS = frozenset([
'text_search',
'order_by',
'sort_order',
])
_INT_PARAMS = frozenset([
'page',
'page_size',
])
_ALL_PARAMS = _LIST_PARAMS | _STRING_PARAMS | _INT_PARAMS
other_params = _ALL_PARAMS
# Test URL encoding (note: '+' is not handled right by httpretty, but it works in practice)
_TEST_STRING = 'Aa1_-:/* '
@ddt.data( @ddt.data(
(['edx/demo/course'], ['course_id'], ['enrollment_modes'], ['123']), (_LIST_PARAMS, ['a', 'b', 'c']),
(['edx/demo/course', 'another/demo/course'], ['course_id', 'enrollment_modes'], (_LIST_PARAMS, [_TEST_STRING]),
['created', 'pacing_type'], ['123', '456']) (_LIST_PARAMS, []),
(_STRING_PARAMS, _TEST_STRING),
(_STRING_PARAMS, ''),
(_INT_PARAMS, 1),
(_INT_PARAMS, 0),
(frozenset(), None),
) )
@ddt.unpack @ddt.unpack
def test_all_parameters(self, course_ids, fields, exclude, programs): def test_all_parameters(self, param_names, param_value):
"""Course summaries can be called with all parameters including programs.""" """Course summaries can be called with all parameters."""
self.kwarg_test(course_ids=course_ids, fields=fields, exclude=exclude, programs=programs) params = {param_name: None for param_name in self._ALL_PARAMS}
params.update({param_name: param_value for param_name in param_names})
self.verify_query_params(**params)
from analyticsclient.tests import ClientTestCase, APIWithPostableIDsTestCase
class CourseTotalsTests(APIWithPostableIDsTestCase, ClientTestCase):
endpoint = 'course_totals'
id_field = 'course_ids'
other_params = frozenset()
from unittest import TestCase from unittest import TestCase
from analyticsclient.constants import activity_type, demographic, education_level, gender, enrollment_modes from analyticsclient.constants import activity_types, demographics, education_levels, genders, enrollment_modes
class HelperTests(TestCase): class HelperTests(TestCase):
...@@ -8,32 +8,32 @@ class HelperTests(TestCase): ...@@ -8,32 +8,32 @@ class HelperTests(TestCase):
""" """
def test_activity_types(self): def test_activity_types(self):
self.assertEqual('any', activity_type.ANY) self.assertEqual('any', activity_types.ANY)
self.assertEqual('attempted_problem', activity_type.ATTEMPTED_PROBLEM) self.assertEqual('attempted_problem', activity_types.ATTEMPTED_PROBLEM)
self.assertEqual('played_video', activity_type.PLAYED_VIDEO) self.assertEqual('played_video', activity_types.PLAYED_VIDEO)
self.assertEqual('posted_forum', activity_type.POSTED_FORUM) self.assertEqual('posted_forum', activity_types.POSTED_FORUM)
def test_demographics(self): def test_demographics(self):
self.assertEqual('birth_year', demographic.BIRTH_YEAR) self.assertEqual('birth_year', demographics.BIRTH_YEAR)
self.assertEqual('education', demographic.EDUCATION) self.assertEqual('education', demographics.EDUCATION)
self.assertEqual('gender', demographic.GENDER) self.assertEqual('gender', demographics.GENDER)
def test_education_levels(self): def test_education_levels(self):
self.assertEqual('none', education_level.NONE) self.assertEqual('none', education_levels.NONE)
self.assertEqual('other', education_level.OTHER) self.assertEqual('other', education_levels.OTHER)
self.assertEqual('primary', education_level.PRIMARY) self.assertEqual('primary', education_levels.PRIMARY)
self.assertEqual('junior_secondary', education_level.JUNIOR_SECONDARY) self.assertEqual('junior_secondary', education_levels.JUNIOR_SECONDARY)
self.assertEqual('secondary', education_level.SECONDARY) self.assertEqual('secondary', education_levels.SECONDARY)
self.assertEqual('associates', education_level.ASSOCIATES) self.assertEqual('associates', education_levels.ASSOCIATES)
self.assertEqual('bachelors', education_level.BACHELORS) self.assertEqual('bachelors', education_levels.BACHELORS)
self.assertEqual('masters', education_level.MASTERS) self.assertEqual('masters', education_levels.MASTERS)
self.assertEqual('doctorate', education_level.DOCTORATE) self.assertEqual('doctorate', education_levels.DOCTORATE)
def test_genders(self): def test_genders(self):
self.assertEqual('female', gender.FEMALE) self.assertEqual('female', genders.FEMALE)
self.assertEqual('male', gender.MALE) self.assertEqual('male', genders.MALE)
self.assertEqual('other', gender.OTHER) self.assertEqual('other', genders.OTHER)
self.assertEqual('unknown', gender.UNKNOWN) self.assertEqual('unknown', genders.UNKNOWN)
def test_enrollment_modes(self): def test_enrollment_modes(self):
self.assertEqual('audit', enrollment_modes.AUDIT) self.assertEqual('audit', enrollment_modes.AUDIT)
......
...@@ -8,3 +8,4 @@ class ProgramsTests(APIListTestCase, ClientTestCase): ...@@ -8,3 +8,4 @@ class ProgramsTests(APIListTestCase, ClientTestCase):
endpoint = 'programs' endpoint = 'programs'
id_field = 'program_ids' id_field = 'program_ids'
other_params = frozenset()
...@@ -2,7 +2,7 @@ from distutils.core import setup ...@@ -2,7 +2,7 @@ from distutils.core import setup
setup( setup(
name='edx-analytics-data-api-client', name='edx-analytics-data-api-client',
version='0.12.0', version='0.13.0',
packages=['analyticsclient', 'analyticsclient.constants'], packages=['analyticsclient', 'analyticsclient.constants'],
url='https://github.com/edx/edx-analytics-data-api-client', url='https://github.com/edx/edx-analytics-data-api-client',
description='Client used to access edX analytics data warehouse', description='Client used to access edX analytics data warehouse',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment