Commit c3348a15, authored by Clinton Blackburn and committed by Peter Fogg

Updated data loading code

- Management command now retrieves JWT access tokens
- Data loaders now properly initialize API clients with the correct authentication mechanism

 ECOM-4414
parent aef89279
...@@ -7,6 +7,7 @@ from urllib.parse import urljoin ...@@ -7,6 +7,7 @@ from urllib.parse import urljoin
import html2text import html2text
from dateutil.parser import parse from dateutil.parser import parse
from django.conf import settings from django.conf import settings
from django.utils.functional import cached_property
from edx_rest_api_client.client import EdxRestApiClient from edx_rest_api_client.client import EdxRestApiClient
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
...@@ -29,15 +30,40 @@ class AbstractDataLoader(metaclass=abc.ABCMeta): ...@@ -29,15 +30,40 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
""" """
PAGE_SIZE = 50 PAGE_SIZE = 50
SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)
def __init__(self, api_url, access_token=None): def __init__(self, api_url, access_token, token_type):
""" """
Arguments: Arguments:
api_url (str): URL of the API from which data is loaded api_url (str): URL of the API from which data is loaded
access_token (str): OAuth2 access token access_token (str): OAuth2 access token
token_type (str): The type of access token passed in (e.g. Bearer, JWT)
""" """
token_type = token_type.lower()
if token_type not in self.SUPPORTED_TOKEN_TYPES:
raise ValueError('The token type {token_type} is invalid!'.format(token_type=token_type))
self.access_token = access_token self.access_token = access_token
self.api_url = api_url self.api_url = api_url
self.token_type = token_type
@cached_property
def api_client(self):
    """
    Build the authenticated client used to call the API from which data is loaded.

    The access token is handed to the client as a JWT when the loader's token type
    is 'jwt'; for any other token type it is passed as an OAuth2 access token.
    Because this is a cached property, the client is created on first access and
    the same instance is returned afterwards.

    Returns:
        EdxRestApiClient
    """
    # Pick the keyword under which EdxRestApiClient expects this kind of token.
    auth_kwarg = 'jwt' if self.token_type == 'jwt' else 'oauth_access_token'
    return EdxRestApiClient(self.api_url, **{auth_kwarg: self.access_token})
@abc.abstractmethod @abc.abstractmethod
def ingest(self): # pragma: no cover def ingest(self): # pragma: no cover
...@@ -100,7 +126,7 @@ class OrganizationsApiDataLoader(AbstractDataLoader): ...@@ -100,7 +126,7 @@ class OrganizationsApiDataLoader(AbstractDataLoader):
""" Loads organizations from the Organizations API. """ """ Loads organizations from the Organizations API. """
def ingest(self): def ingest(self):
client = EdxRestApiClient(self.api_url, oauth_access_token=self.access_token) client = self.api_client
count = None count = None
page = 1 page = 1
...@@ -142,7 +168,7 @@ class CoursesApiDataLoader(AbstractDataLoader): ...@@ -142,7 +168,7 @@ class CoursesApiDataLoader(AbstractDataLoader):
""" Loads course runs from the Courses API. """ """ Loads course runs from the Courses API. """
def ingest(self): def ingest(self):
client = EdxRestApiClient(self.api_url, oauth_access_token=self.access_token) client = self.api_client
count = None count = None
page = 1 page = 1
...@@ -237,7 +263,7 @@ class DrupalApiDataLoader(AbstractDataLoader): ...@@ -237,7 +263,7 @@ class DrupalApiDataLoader(AbstractDataLoader):
"""Loads course runs from the Drupal API.""" """Loads course runs from the Drupal API."""
def ingest(self): def ingest(self):
client = EdxRestApiClient(self.api_url) client = self.api_client
logger.info('Refreshing Courses and CourseRuns from %s...', self.api_url) logger.info('Refreshing Courses and CourseRuns from %s...', self.api_url)
response = client.courses.get() response = client.courses.get()
...@@ -359,7 +385,7 @@ class EcommerceApiDataLoader(AbstractDataLoader): ...@@ -359,7 +385,7 @@ class EcommerceApiDataLoader(AbstractDataLoader):
""" Loads course seats from the E-Commerce API. """ """ Loads course seats from the E-Commerce API. """
def ingest(self): def ingest(self):
client = EdxRestApiClient(self.api_url, oauth_access_token=self.access_token) client = self.api_client
count = None count = None
page = 1 page = 1
......
import logging import logging
from django.conf import settings from django.conf import settings
from django.core.management import BaseCommand from django.core.management import BaseCommand, CommandError
from edx_rest_api_client.client import EdxRestApiClient from edx_rest_api_client.client import EdxRestApiClient
from course_discovery.apps.course_metadata.data_loaders import ( from course_discovery.apps.course_metadata.data_loaders import (
...@@ -23,23 +23,45 @@ class Command(BaseCommand): ...@@ -23,23 +23,45 @@ class Command(BaseCommand):
help='OAuth2 access token used to authenticate API calls.' help='OAuth2 access token used to authenticate API calls.'
) )
parser.add_argument(
'--token_type',
action='store',
dest='token_type',
default=None,
help='The type of access token being passed (e.g. Bearer, JWT).'
)
def handle(self, *args, **options): def handle(self, *args, **options):
access_token = options.get('access_token') access_token = options.get('access_token')
token_type = options.get('token_type')
if access_token and not token_type:
raise CommandError('The token_type must be specified when passing in an access token!')
if not access_token: if not access_token:
logger.info('No access token provided. Retrieving access token using client_credential flow...') logger.info('No access token provided. Retrieving access token using client_credential flow...')
token_type = 'JWT'
try: try:
access_token, __ = EdxRestApiClient.get_oauth_access_token( access_token, __ = EdxRestApiClient.get_oauth_access_token(
'{root}/access_token'.format(root=settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT), '{root}/access_token'.format(root=settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT),
settings.SOCIAL_AUTH_EDX_OIDC_KEY, settings.SOCIAL_AUTH_EDX_OIDC_KEY,
settings.SOCIAL_AUTH_EDX_OIDC_SECRET settings.SOCIAL_AUTH_EDX_OIDC_SECRET,
token_type=token_type
) )
except Exception: except Exception:
logger.exception('No access token provided or acquired through client_credential flow.') logger.exception('No access token provided or acquired through client_credential flow.')
raise raise
OrganizationsApiDataLoader(settings.ORGANIZATIONS_API_URL, access_token).ingest() loaders = (
CoursesApiDataLoader(settings.COURSES_API_URL, access_token).ingest() (OrganizationsApiDataLoader, settings.ORGANIZATIONS_API_URL,),
EcommerceApiDataLoader(settings.ECOMMERCE_API_URL, access_token).ingest() (CoursesApiDataLoader, settings.COURSES_API_URL,),
DrupalApiDataLoader(settings.MARKETING_API_URL).ingest() (EcommerceApiDataLoader, settings.ECOMMERCE_API_URL,),
(DrupalApiDataLoader, settings.MARKETING_API_URL,),
)
for loader_class, api_url in loaders:
try:
loader_class(api_url, access_token, token_type).ingest()
except Exception:
logger.exception('%s failed!', loader_class.__name__)
...@@ -8,6 +8,8 @@ import ddt ...@@ -8,6 +8,8 @@ import ddt
import responses import responses
from django.conf import settings from django.conf import settings
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from edx_rest_api_client.auth import BearerAuth, SuppliedJwtAuth
from edx_rest_api_client.client import EdxRestApiClient
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from pytz import UTC from pytz import UTC
...@@ -22,6 +24,7 @@ from course_discovery.apps.course_metadata.tests.factories import ( ...@@ -22,6 +24,7 @@ from course_discovery.apps.course_metadata.tests.factories import (
) )
ACCESS_TOKEN = 'secret' ACCESS_TOKEN = 'secret'
ACCESS_TOKEN_TYPE = 'Bearer'
COURSES_API_URL = 'https://lms.example.com/api/courses/v1' COURSES_API_URL = 'https://lms.example.com/api/courses/v1'
ORGANIZATIONS_API_URL = 'https://lms.example.com/api/organizations/v0' ORGANIZATIONS_API_URL = 'https://lms.example.com/api/organizations/v0'
MARKETING_API_URL = 'https://example.com/api/catalog/v2/' MARKETING_API_URL = 'https://example.com/api/catalog/v2/'
...@@ -64,13 +67,15 @@ class AbstractDataLoaderTest(TestCase): ...@@ -64,13 +67,15 @@ class AbstractDataLoaderTest(TestCase):
self.assertFalse(instance.__class__.objects.filter(pk=instance.pk).exists()) # pylint: disable=no-member self.assertFalse(instance.__class__.objects.filter(pk=instance.pk).exists()) # pylint: disable=no-member
# pylint: disable=not-callable
@ddt.ddt
class DataLoaderTestMixin(object): class DataLoaderTestMixin(object):
api_url = None api_url = None
loader_class = None loader_class = None
def setUp(self): def setUp(self):
super(DataLoaderTestMixin, self).setUp() super(DataLoaderTestMixin, self).setUp()
self.loader = self.loader_class(self.api_url, ACCESS_TOKEN) # pylint: disable=not-callable self.loader = self.loader_class(self.api_url, ACCESS_TOKEN, ACCESS_TOKEN_TYPE)
def assert_api_called(self, expected_num_calls, check_auth=True): def assert_api_called(self, expected_num_calls, check_auth=True):
""" Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """ """ Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """
...@@ -82,8 +87,30 @@ class DataLoaderTestMixin(object): ...@@ -82,8 +87,30 @@ class DataLoaderTestMixin(object):
""" Verify the constructor sets the appropriate attributes. """ """ Verify the constructor sets the appropriate attributes. """
self.assertEqual(self.loader.api_url, self.api_url) self.assertEqual(self.loader.api_url, self.api_url)
self.assertEqual(self.loader.access_token, ACCESS_TOKEN) self.assertEqual(self.loader.access_token, ACCESS_TOKEN)
self.assertEqual(self.loader.token_type, ACCESS_TOKEN_TYPE.lower())
def test_init_with_unsupported_token_type(self):
    """ The constructor must raise ValueError when given a token type it does not support. """
    unsupported_token_type = 'not-supported'
    self.assertRaises(ValueError, self.loader_class, self.api_url, ACCESS_TOKEN, unsupported_token_type)
@ddt.unpack
@ddt.data(
    ('Bearer', BearerAuth),
    ('JWT', SuppliedJwtAuth),
)
def test_api_client(self, token_type, expected_auth_class):
    """ The api_client property should yield a client whose session uses the auth class for the token type. """
    api_client = self.loader_class(self.api_url, ACCESS_TOKEN, token_type).api_client
    self.assertIsInstance(api_client, EdxRestApiClient)

    # NOTE (CCB): Mocking the constructor and checking the auth arguments passed to it was the
    # preferred approach, but it proved nearly impossible. Inspecting the session's auth object is
    # the next best alternative. It is brittle, and may break if the underlying request class of
    # EdxRestApiClient ever changes.
    session_auth = api_client._store['session'].auth  # pylint: disable=protected-access
    self.assertIsInstance(session_auth, expected_auth_class)
@ddt.ddt
@override_settings(ORGANIZATIONS_API_URL=ORGANIZATIONS_API_URL) @override_settings(ORGANIZATIONS_API_URL=ORGANIZATIONS_API_URL)
class OrganizationsApiDataLoaderTests(DataLoaderTestMixin, TestCase): class OrganizationsApiDataLoaderTests(DataLoaderTestMixin, TestCase):
api_url = ORGANIZATIONS_API_URL api_url = ORGANIZATIONS_API_URL
...@@ -378,8 +405,8 @@ class CoursesApiDataLoaderTests(DataLoaderTestMixin, TestCase): ...@@ -378,8 +405,8 @@ class CoursesApiDataLoaderTests(DataLoaderTestMixin, TestCase):
self.assertIsNone(actual) self.assertIsNone(actual)
@override_settings(MARKETING_API_URL=MARKETING_API_URL)
@ddt.ddt @ddt.ddt
@override_settings(MARKETING_API_URL=MARKETING_API_URL)
class DrupalApiDataLoaderTests(DataLoaderTestMixin, TestCase): class DrupalApiDataLoaderTests(DataLoaderTestMixin, TestCase):
EXISTING_COURSE_AND_RUN_DATA = ( EXISTING_COURSE_AND_RUN_DATA = (
{ {
...@@ -992,10 +1019,13 @@ class EcommerceApiDataLoaderTests(DataLoaderTestMixin, TestCase): ...@@ -992,10 +1019,13 @@ class EcommerceApiDataLoaderTests(DataLoaderTestMixin, TestCase):
({"attribute_values": []}, Seat.AUDIT), ({"attribute_values": []}, Seat.AUDIT),
({"attribute_values": [{'name': 'certificate_type', 'value': 'professional'}]}, 'professional'), ({"attribute_values": [{'name': 'certificate_type', 'value': 'professional'}]}, 'professional'),
( (
{"attribute_values": [ {
{'name': 'other_data', 'value': 'other'}, "attribute_values": [
{'name': 'certificate_type', 'value': 'credit'} {'name': 'other_data', 'value': 'other'},
]}, 'credit' {'name': 'certificate_type', 'value': 'credit'}
]
},
'credit'
), ),
({"attribute_values": [{'name': 'other_data', 'value': 'other'}]}, Seat.AUDIT), ({"attribute_values": [{'name': 'other_data', 'value': 'other'}]}, Seat.AUDIT),
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment