Commit 9fd7ad79 by Matthew Piatetsky

Get username from token and remove token command line args

parent 365c35a7
......@@ -57,18 +57,6 @@ class ManagementCommandViewTestMixin(object):
self.assertDictContainsSubset(expected, kwargs)
class RefreshCourseMetadataTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the refresh_course_metadata management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command'
command_name = 'refresh_course_metadata'
path = reverse('api:v1:management-refresh-course-metadata')
def test_success_response(self):
""" Verify a successful response calls the management command and returns the plain text output. """
super(RefreshCourseMetadataTests, self).test_success_response()
self.assert_successful_response(access_token='abc123')
class UpdateIndexTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the update_index management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command'
......
......@@ -517,26 +517,6 @@ class ManagementViewSet(viewsets.ViewSet):
permission_classes = (IsSuperuser,)
@list_route(methods=['post'])
def refresh_course_metadata(self, request):
""" Refresh the course metadata from external data sources.
---
parameters:
- name: access_token
description: OAuth access token to use in lieu of that issued to the service.
required: false
type: string
paramType: form
multiple: false
"""
access_token = request.data.get('access_token')
kwargs = {'access_token': access_token} if access_token else {}
name = 'refresh_course_metadata'
output = self.run_command(request, name, **kwargs)
return Response(output, content_type='text/plain')
@list_route(methods=['post'])
def update_index(self, request):
""" Update the search index. """
name = 'update_index'
......
......@@ -23,10 +23,10 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
"""
PAGE_SIZE = 50
SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)
MARKDOWN_CLEANUP_REGEX = re.compile(r'^<p>(.*)</p>$')
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
"""
Arguments:
partner (Partner): Partner which owns the APIs and data being loaded
......@@ -39,9 +39,6 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
if token_type:
token_type = token_type.lower()
if token_type not in self.SUPPORTED_TOKEN_TYPES:
raise ValueError('The token type {token_type} is invalid!'.format(token_type=token_type))
self.access_token = access_token
self.token_type = token_type
self.partner = partner
......@@ -49,6 +46,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
self.max_workers = max_workers
self.is_threadsafe = is_threadsafe
self.username = kwargs.get('username')
@cached_property
def api_client(self):
......
......@@ -96,7 +96,7 @@ class CoursesApiDataLoader(AbstractDataLoader):
self._process_response(response)
def _make_request(self, page):
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, username='course_discovery_worker')
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, username=self.username)
def _process_response(self, response):
results = response['results']
......@@ -294,9 +294,10 @@ class ProgramsApiDataLoader(AbstractDataLoader):
image_height = 480
XSERIES = None
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
super(ProgramsApiDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe
partner, api_url, access_token, token_type, max_workers, is_threadsafe, **kwargs
)
self.XSERIES = ProgramType.objects.get(name='XSeries')
......
......@@ -23,9 +23,10 @@ logger = logging.getLogger(__name__)
class AbstractMarketingSiteDataLoader(AbstractDataLoader):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
super(AbstractMarketingSiteDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe
partner, api_url, access_token, token_type, max_workers, is_threadsafe, **kwargs
)
if not (self.partner.marketing_site_api_username and self.partner.marketing_site_api_password):
......
import ddt
import responses
from edx_rest_api_client.auth import BearerAuth, SuppliedJwtAuth
from edx_rest_api_client.auth import SuppliedJwtAuth
from edx_rest_api_client.client import EdxRestApiClient
from course_discovery.apps.course_metadata.tests.factories import PartnerFactory
ACCESS_TOKEN = 'secret'
ACCESS_TOKEN_TYPE = 'Bearer'
ACCESS_TOKEN_TYPE = 'JWT'
@ddt.ddt
class ApiClientTestMixin(object):
@ddt.unpack
@ddt.data(
('Bearer', BearerAuth),
('JWT', SuppliedJwtAuth),
)
def test_api_client(self, token_type, expected_auth_class):
def test_api_client(self):
""" Verify the property returns an API client with the correct authentication. """
loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, token_type)
loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, ACCESS_TOKEN_TYPE)
client = loader.api_client
self.assertIsInstance(client, EdxRestApiClient)
# NOTE (CCB): My initial preference was to mock the constructor and ensure the correct auth arguments
# were passed. However, that seems nearly impossible. This is the next best alternative. It is brittle, and
# may break if we ever change the underlying request class of EdxRestApiClient.
self.assertIsInstance(client._store['session'].auth, expected_auth_class) # pylint: disable=protected-access
self.assertIsInstance(client._store['session'].auth, SuppliedJwtAuth) # pylint: disable=protected-access
# pylint: disable=not-callable
......@@ -45,15 +38,10 @@ class DataLoaderTestMixin(object):
""" Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """
self.assertEqual(len(responses.calls), expected_num_calls)
if check_auth:
self.assertEqual(responses.calls[0].request.headers['Authorization'], 'Bearer {}'.format(ACCESS_TOKEN))
self.assertEqual(responses.calls[0].request.headers['Authorization'], 'JWT {}'.format(ACCESS_TOKEN))
def test_init(self):
""" Verify the constructor sets the appropriate attributes. """
self.assertEqual(self.loader.partner.short_code, self.partner.short_code)
self.assertEqual(self.loader.access_token, ACCESS_TOKEN)
self.assertEqual(self.loader.token_type, ACCESS_TOKEN_TYPE.lower())
def test_init_with_unsupported_token_type(self):
""" Verify the constructor raises an error if an unsupported token type is passed in. """
with self.assertRaises(ValueError):
self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, 'not-supported')
......@@ -5,6 +5,7 @@ import logging
from django.core.management import BaseCommand, CommandError
from django.db import connection
from edx_rest_api_client.client import EdxRestApiClient
import jwt
import waffle
from course_discovery.apps.core.models import Partner
......@@ -21,9 +22,9 @@ from course_discovery.apps.course_metadata.models import Course
logger = logging.getLogger(__name__)
def execute_loader(loader_class, *loader_args):
def execute_loader(loader_class, *loader_args, **loader_kwargs):
try:
loader_class(*loader_args).ingest()
loader_class(*loader_args, **loader_kwargs).ingest()
except Exception: # pylint: disable=broad-except
logger.exception('%s failed!', loader_class.__name__)
......@@ -52,22 +53,6 @@ class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
'--access_token',
action='store',
dest='access_token',
default=None,
help='OAuth2 access token used to authenticate API calls.'
)
parser.add_argument(
'--token_type',
action='store',
dest='token_type',
default=None,
help='The type of access token being passed (e.g. Bearer, JWT).'
)
parser.add_argument(
'--partner_code',
action='store',
dest='partner_code',
......@@ -98,16 +83,9 @@ class Command(BaseCommand):
if not partners:
raise CommandError('No partners available!')
for partner in partners:
access_token = options.get('access_token')
token_type = options.get('token_type')
if access_token and not token_type:
raise CommandError('The token_type must be specified when passing in an access token!')
if not access_token:
logger.info('No access token provided. Retrieving access token using client_credential flow...')
token_type = 'JWT'
for partner in partners:
logger.info('Retrieving access token for partner [{}]'.format(partner_code))
try:
access_token, __ = EdxRestApiClient.get_oauth_access_token(
......@@ -117,8 +95,10 @@ class Command(BaseCommand):
token_type=token_type
)
except Exception:
logger.exception('No access token provided or acquired through client_credential flow.')
logger.exception('No access token acquired through client_credential flow.')
raise
username = jwt.decode(access_token, verify=False)['preferred_username']
kwargs = {'username': username} if username else {}
# The Linux kernel implements copy-on-write when fork() is called to create a new
# process. Pages that the parent and child processes share, such as the database
......@@ -192,6 +172,7 @@ class Command(BaseCommand):
token_type,
(max_workers_override or max_workers),
is_threadsafe,
**kwargs,
)
else:
# Flatten pipeline and run serially.
......@@ -205,6 +186,7 @@ class Command(BaseCommand):
token_type,
(max_workers_override or max_workers),
is_threadsafe,
**kwargs,
)
# TODO Cleanup CourseRun overrides equivalent to the Course values.
import json
import ddt
import jwt
import mock
import responses
from django.core.management import call_command, CommandError
......@@ -19,8 +20,8 @@ from course_discovery.apps.course_metadata.data_loaders.tests import mock_data
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import CourseFactory
ACCESS_TOKEN = 'secret'
JSON = 'application/json'
ACCESS_TOKEN = str(jwt.encode({'preferred_username': 'bob'}, 'secret'), 'utf-8')
@ddt.ddt
......@@ -134,11 +135,15 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase):
command_args = ['--partner_code=invalid']
call_command('refresh_course_metadata', *command_args)
def test_refresh_course_metadata_with_no_token_type(self):
""" Verify an error is raised if an access token is passed in without a token type. """
with self.assertRaises(CommandError):
command_args = ['--access_token=test-access-token']
call_command('refresh_course_metadata', *command_args)
def test_refresh_course_metadata_errors_with_no_token(self):
""" Verify an exception is raised and an error is logged if an access token is not acquired. """
with mock.patch('edx_rest_api_client.client.EdxRestApiClient.get_oauth_access_token', side_effect=Exception):
logger = 'course_discovery.apps.course_metadata.management.commands.refresh_course_metadata.logger'
with mock.patch(logger) as mock_logger:
with self.assertRaises(Exception):
call_command('refresh_course_metadata')
expected_calls = [mock.call('No access token acquired through client_credential flow.')]
mock_logger.exception.assert_has_calls(expected_calls)
def test_refresh_course_metadata_with_loader_exception(self):
""" Verify execution continues if an individual data loader fails. """
......
......@@ -21,12 +21,11 @@ Retrieving Course Metadata
The ``refresh_course_metadata`` command in :file:`course_discovery/apps/course_metadata/management/commands/refresh_course_metadata.py` is used to retrieve metadata. This is run daily in production through a Jenkins job, and can be manually run to set up your local environment. The data loaders are each run in series by the command. The data
loaders should be idempotent -- that is, running this command once will populate the database, and if nothing has
changed upstream then running it again should not change the database.
For example, if you use the JWT access token "secret-ecommerce-key" to authenticate API calls, run the following:
The command retrieves a JWT access token from the configured OpenID Connect provider.
.. code-block:: bash
$ ./manage.py refresh_course_metadata --access_token secret-ecommerce-key --token_type jwt
$ ./manage.py refresh_course_metadata
QuerySets
---------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment