Commit 228912dd by Clinton Blackburn Committed by GitHub

Restored parts of original AbstractDataLoader architecture (#194)

- The api_url is once again passed to the constructor.
- Restored the api_client property
- Simplified instantiation of data loaders in refresh_course_metadata

ECOM-5099
parent 26c46e7c
......@@ -6,6 +6,7 @@ from urllib.parse import urljoin
import html2text
from dateutil.parser import parse
from django.utils.functional import cached_property
from edx_rest_api_client.client import EdxRestApiClient
from opaque_keys.edx.keys import CourseKey
......@@ -23,6 +24,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
""" Base class for all data loaders.
Attributes:
api_url (str): URL of the API from which data is loaded
partner (Partner): Partner which owns the data for this data loader
access_token (str): OAuth2 access token
PAGE_SIZE (int): Number of items to load per API call
......@@ -31,12 +33,13 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
PAGE_SIZE = 50
SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)
def __init__(self, partner, access_token, token_type):
def __init__(self, partner, api_url, access_token, token_type):
"""
Arguments:
partner (Partner): Partner which owns the APIs and data being loaded
api_url (str): URL of the API from which data is loaded
access_token (str): OAuth2 access token
token_type (str): The type of access token passed in (e.g. Bearer, JWT)
partner (Partner): The Partner which owns the APIs and data being loaded
"""
token_type = token_type.lower()
......@@ -46,8 +49,10 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
self.access_token = access_token
self.token_type = token_type
self.partner = partner
self.api_url = api_url
def get_api_client(self, api_url):
@cached_property
def api_client(self):
"""
Returns an authenticated API client ready to call the API from which data is loaded.
......@@ -61,7 +66,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
else:
kwargs['oauth_access_token'] = self.access_token
return EdxRestApiClient(api_url, **kwargs)
return EdxRestApiClient(self.api_url, **kwargs)
@abc.abstractmethod
def ingest(self): # pragma: no cover
......@@ -125,14 +130,13 @@ class OrganizationsApiDataLoader(AbstractDataLoader):
def ingest(self):
api_url = self.partner.organizations_api_url
client = self.get_api_client(api_url)
count = None
page = 1
logger.info('Refreshing Organizations from %s...', api_url)
while page:
response = client.organizations().get(page=page, page_size=self.PAGE_SIZE)
response = self.api_client.organizations().get(page=page, page_size=self.PAGE_SIZE)
count = response['count']
results = response['results']
logger.info('Retrieved %d organizations...', len(results))
......@@ -169,14 +173,13 @@ class CoursesApiDataLoader(AbstractDataLoader):
def ingest(self):
api_url = self.partner.courses_api_url
client = self.get_api_client(api_url)
count = None
page = 1
logger.info('Refreshing Courses and CourseRuns from %s...', api_url)
while page:
response = client.courses().get(page=page, page_size=self.PAGE_SIZE)
response = self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE)
count = response['pagination']['count']
results = response['results']
logger.info('Retrieved %d course runs...', len(results))
......@@ -276,9 +279,8 @@ class DrupalApiDataLoader(AbstractDataLoader):
def ingest(self):
api_url = self.partner.marketing_site_api_url
client = self.get_api_client(api_url)
logger.info('Refreshing Courses and CourseRuns from %s...', api_url)
response = client.courses.get()
response = self.api_client.courses.get()
data = response['items']
logger.info('Retrieved %d course runs...', len(data))
......@@ -422,14 +424,13 @@ class EcommerceApiDataLoader(AbstractDataLoader):
def ingest(self):
api_url = self.partner.ecommerce_api_url
client = self.get_api_client(api_url)
count = None
page = 1
logger.info('Refreshing course seats from %s...', api_url)
while page:
response = client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
response = self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, include_products=True)
count = response['count']
results = response['results']
logger.info('Retrieved %d course seats...', len(results))
......@@ -509,14 +510,13 @@ class ProgramsApiDataLoader(AbstractDataLoader):
def ingest(self):
api_url = self.partner.programs_api_url
client = self.get_api_client(api_url)
count = None
page = 1
logger.info('Refreshing programs from %s...', api_url)
while page:
response = client.programs.get(page=page, page_size=self.PAGE_SIZE)
response = self.api_client.programs.get(page=page, page_size=self.PAGE_SIZE)
count = response['count']
results = response['results']
logger.info('Retrieved %d programs...', len(results))
......
......@@ -74,22 +74,17 @@ class Command(BaseCommand):
logger.exception('No access token provided or acquired through client_credential flow.')
raise
loaders = []
if partner.organizations_api_url:
loaders.append(OrganizationsApiDataLoader)
if partner.courses_api_url:
loaders.append(CoursesApiDataLoader)
if partner.ecommerce_api_url:
loaders.append(EcommerceApiDataLoader)
if partner.marketing_site_api_url:
loaders.append(DrupalApiDataLoader)
if partner.programs_api_url:
loaders.append(ProgramsApiDataLoader)
if loaders:
for loader_class in loaders:
data_loaders = (
(partner.organizations_api_url, OrganizationsApiDataLoader,),
(partner.courses_api_url, CoursesApiDataLoader,),
(partner.ecommerce_api_url, EcommerceApiDataLoader,),
(partner.marketing_site_api_url, DrupalApiDataLoader,),
(partner.programs_api_url, ProgramsApiDataLoader,),
)
for api_url, loader_class in data_loaders:
if api_url:
try:
loader_class(partner, access_token, token_type).ingest()
loader_class(partner, api_url, access_token, token_type).ingest()
except Exception: # pylint: disable=broad-except
logger.exception('%s failed!', loader_class.__name__)
import json
import mock
import responses
from django.core.management import call_command, CommandError
from django.test import TestCase
from course_discovery.apps.core.tests.factories import PartnerFactory
from course_discovery.apps.core.tests.utils import mock_api_callback
from course_discovery.apps.course_metadata.data_loaders import (
CoursesApiDataLoader, DrupalApiDataLoader, OrganizationsApiDataLoader, EcommerceApiDataLoader, ProgramsApiDataLoader
)
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program
from course_discovery.apps.course_metadata.tests import mock_data
......@@ -19,20 +23,23 @@ class RefreshCourseMetadataCommandTests(TestCase):
self.partner = PartnerFactory()
self.mock_access_token_api()
def mock_apis(self):
self.mock_organizations_api()
self.mock_lms_courses_api()
self.mock_ecommerce_courses_api()
self.mock_marketing_courses_api()
self.mock_programs_api()
def mock_access_token_api(self):
def mock_access_token_api(self, requests_mock=None):
body = {
'access_token': ACCESS_TOKEN,
'expires_in': 30
}
requests_mock = requests_mock or responses
url = self.partner.oidc_url_root.strip('/') + '/access_token'
responses.add_callback(
requests_mock.add_callback(
responses.POST,
url,
callback=mock_api_callback(url, body, results_key=False),
......@@ -83,7 +90,7 @@ class RefreshCourseMetadataCommandTests(TestCase):
self.partner.marketing_site_api_url + 'courses/',
body=json.dumps(body),
status=200,
content_type='application/json'
content_type=JSON
)
return body['items']
......@@ -101,12 +108,10 @@ class RefreshCourseMetadataCommandTests(TestCase):
@responses.activate
def test_refresh_course_metadata(self):
""" Verify the refresh_course_metadata management command creates new objects. """
self.mock_apis()
call_command('refresh_course_metadata')
organizations = Organization.objects.all()
for organization in organizations:
print(organization.key)
self.assertEqual(organizations.count(), 3)
for organization in organizations:
......@@ -134,6 +139,7 @@ class RefreshCourseMetadataCommandTests(TestCase):
@responses.activate
def test_refresh_course_metadata_with_invalid_partner_code(self):
""" Verify an error is raised if an invalid partner code is passed on the command line. """
self.mock_apis()
with self.assertRaises(CommandError):
command_args = ['--partner_code=invalid']
call_command('refresh_course_metadata', *command_args)
......@@ -141,6 +147,21 @@ class RefreshCourseMetadataCommandTests(TestCase):
@responses.activate
def test_refresh_course_metadata_with_no_token_type(self):
""" Verify an error is raised if an access token is passed in without a token type. """
self.mock_apis()
with self.assertRaises(CommandError):
command_args = ['--access_token=test-access-token']
call_command('refresh_course_metadata', *command_args)
def test_refresh_course_metadata_with_loader_exception(self):
""" Verify execution continues if an individual data loader fails. """
with responses.RequestsMock() as rsps:
self.mock_access_token_api(rsps)
logger_target = 'course_discovery.apps.course_metadata.management.commands.refresh_course_metadata.logger'
with mock.patch(logger_target) as mock_logger:
call_command('refresh_course_metadata')
loader_classes = (OrganizationsApiDataLoader, CoursesApiDataLoader, EcommerceApiDataLoader,
DrupalApiDataLoader, ProgramsApiDataLoader)
expected_calls = [mock.call('%s failed!', loader_class.__name__) for loader_class in loader_classes]
mock_logger.exception.assert_has_calls(expected_calls)
......@@ -12,14 +12,14 @@ from edx_rest_api_client.client import EdxRestApiClient
from opaque_keys.edx.keys import CourseKey
from pytz import UTC
from course_discovery.apps.core.tests.factories import PartnerFactory
from course_discovery.apps.core.tests.utils import mock_api_callback
from course_discovery.apps.course_metadata.data_loaders import (
OrganizationsApiDataLoader, CoursesApiDataLoader, DrupalApiDataLoader, EcommerceApiDataLoader, AbstractDataLoader,
ProgramsApiDataLoader)
ProgramsApiDataLoader
)
from course_discovery.apps.course_metadata.models import (
Course, CourseOrganization, CourseRun, Image, LanguageTag, Organization, Person, Seat, Subject,
Program)
Course, CourseOrganization, CourseRun, Image, LanguageTag, Organization, Person, Seat, Subject, Program
)
from course_discovery.apps.course_metadata.tests import mock_data
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, SeatFactory, ImageFactory, PartnerFactory, PersonFactory, VideoFactory
......@@ -75,7 +75,11 @@ class DataLoaderTestMixin(object):
def setUp(self):
super(DataLoaderTestMixin, self).setUp()
self.partner = PartnerFactory()
self.loader = self.loader_class(self.partner, ACCESS_TOKEN, ACCESS_TOKEN_TYPE)
self.loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, ACCESS_TOKEN_TYPE)
@property
def api_url(self): # pragma: no cover
raise NotImplementedError
def assert_api_called(self, expected_num_calls, check_auth=True):
""" Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """
......@@ -92,17 +96,17 @@ class DataLoaderTestMixin(object):
def test_init_with_unsupported_token_type(self):
""" Verify the constructor raises an error if an unsupported token type is passed in. """
with self.assertRaises(ValueError):
self.loader_class(self.partner, ACCESS_TOKEN, 'not-supported')
self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, 'not-supported')
@ddt.unpack
@ddt.data(
('Bearer', BearerAuth),
('JWT', SuppliedJwtAuth),
)
def test_get_api_client(self, token_type, expected_auth_class):
def test_api_client(self, token_type, expected_auth_class):
""" Verify the property returns an API client with the correct authentication. """
loader = self.loader_class(self.partner, ACCESS_TOKEN, token_type)
client = loader.get_api_client(self.partner.programs_api_url)
loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, token_type)
client = loader.api_client
self.assertIsInstance(client, EdxRestApiClient)
# NOTE (CCB): My initial preference was to mock the constructor and ensure the correct auth arguments
# were passed. However, that seems nearly impossible. This is the next best alternative. It is brittle, and
......@@ -114,9 +118,13 @@ class DataLoaderTestMixin(object):
class OrganizationsApiDataLoaderTests(DataLoaderTestMixin, TestCase):
loader_class = OrganizationsApiDataLoader
@property
def api_url(self):
return self.partner.organizations_api_url
def mock_api(self):
bodies = mock_data.ORGANIZATIONS_API_BODIES
url = self.partner.organizations_api_url + 'organizations/'
url = self.api_url + 'organizations/'
responses.add_callback(
responses.GET,
url,
......@@ -164,9 +172,13 @@ class OrganizationsApiDataLoaderTests(DataLoaderTestMixin, TestCase):
class CoursesApiDataLoaderTests(DataLoaderTestMixin, TestCase):
loader_class = CoursesApiDataLoader
@property
def api_url(self):
return self.partner.courses_api_url
def mock_api(self):
bodies = mock_data.COURSES_API_BODIES
url = self.partner.courses_api_url + 'courses/'
url = self.api_url + 'courses/'
responses.add_callback(
responses.GET,
url,
......@@ -299,6 +311,10 @@ class CoursesApiDataLoaderTests(DataLoaderTestMixin, TestCase):
class DrupalApiDataLoaderTests(DataLoaderTestMixin, TestCase):
loader_class = DrupalApiDataLoader
@property
def api_url(self):
return self.partner.marketing_site_api_url
def setUp(self):
super(DrupalApiDataLoaderTests, self).setUp()
for course_dict in mock_data.EXISTING_COURSE_AND_RUN_DATA:
......@@ -328,7 +344,7 @@ class DrupalApiDataLoaderTests(DataLoaderTestMixin, TestCase):
body = mock_data.MARKETING_API_BODY
responses.add(
responses.GET,
self.partner.marketing_site_api_url + 'courses/',
self.api_url + 'courses/',
body=json.dumps(body),
status=200,
content_type='application/json'
......@@ -474,9 +490,12 @@ class DrupalApiDataLoaderTests(DataLoaderTestMixin, TestCase):
class EcommerceApiDataLoaderTests(DataLoaderTestMixin, TestCase):
loader_class = EcommerceApiDataLoader
def mock_api(self):
@property
def api_url(self):
return self.partner.ecommerce_api_url
# create existing seats to be removed by ingest
def mock_api(self):
# Create existing seats to be removed by ingest
audit_run = CourseRunFactory(title_override='audit', key='audit/course/run')
verified_run = CourseRunFactory(title_override='verified', key='verified/course/run')
credit_run = CourseRunFactory(title_override='credit', key='credit/course/run')
......@@ -488,7 +507,7 @@ class EcommerceApiDataLoaderTests(DataLoaderTestMixin, TestCase):
SeatFactory(course_run=no_currency_run, type=Seat.PROFESSIONAL)
bodies = mock_data.ECOMMERCE_API_BODIES
url = self.partner.ecommerce_api_url + 'courses/'
url = self.api_url + 'courses/'
responses.add_callback(
responses.GET,
url,
......@@ -585,9 +604,13 @@ class EcommerceApiDataLoaderTests(DataLoaderTestMixin, TestCase):
class ProgramsApiDataLoaderTests(DataLoaderTestMixin, TestCase):
loader_class = ProgramsApiDataLoader
@property
def api_url(self):
return self.partner.programs_api_url
def mock_api(self):
bodies = mock_data.PROGRAMS_API_BODIES
url = self.partner.programs_api_url + 'programs/'
url = self.api_url + 'programs/'
responses.add_callback(
responses.GET,
url,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment