Commit 9fd7ad79 by Matthew Piatetsky

Get username from token and remove token command line args

parent 365c35a7
...@@ -57,18 +57,6 @@ class ManagementCommandViewTestMixin(object): ...@@ -57,18 +57,6 @@ class ManagementCommandViewTestMixin(object):
self.assertDictContainsSubset(expected, kwargs) self.assertDictContainsSubset(expected, kwargs)
class RefreshCourseMetadataTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the refresh_course_metadata management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command'
command_name = 'refresh_course_metadata'
path = reverse('api:v1:management-refresh-course-metadata')
def test_success_response(self):
""" Verify a successful response calls the management command and returns the plain text output. """
super(RefreshCourseMetadataTests, self).test_success_response()
self.assert_successful_response(access_token='abc123')
class UpdateIndexTests(ManagementCommandViewTestMixin, APITestCase): class UpdateIndexTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the update_index management endpoint. """ """ Tests for the update_index management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command' call_command_path = 'course_discovery.apps.api.v1.views.call_command'
......
...@@ -517,26 +517,6 @@ class ManagementViewSet(viewsets.ViewSet): ...@@ -517,26 +517,6 @@ class ManagementViewSet(viewsets.ViewSet):
permission_classes = (IsSuperuser,) permission_classes = (IsSuperuser,)
@list_route(methods=['post']) @list_route(methods=['post'])
def refresh_course_metadata(self, request):
""" Refresh the course metadata from external data sources.
---
parameters:
- name: access_token
description: OAuth access token to use in lieu of that issued to the service.
required: false
type: string
paramType: form
multiple: false
"""
access_token = request.data.get('access_token')
kwargs = {'access_token': access_token} if access_token else {}
name = 'refresh_course_metadata'
output = self.run_command(request, name, **kwargs)
return Response(output, content_type='text/plain')
@list_route(methods=['post'])
def update_index(self, request): def update_index(self, request):
""" Update the search index. """ """ Update the search index. """
name = 'update_index' name = 'update_index'
......
...@@ -23,10 +23,10 @@ class AbstractDataLoader(metaclass=abc.ABCMeta): ...@@ -23,10 +23,10 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
""" """
PAGE_SIZE = 50 PAGE_SIZE = 50
SUPPORTED_TOKEN_TYPES = ('bearer', 'jwt',)
MARKDOWN_CLEANUP_REGEX = re.compile(r'^<p>(.*)</p>$') MARKDOWN_CLEANUP_REGEX = re.compile(r'^<p>(.*)</p>$')
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False): def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
""" """
Arguments: Arguments:
partner (Partner): Partner which owns the APIs and data being loaded partner (Partner): Partner which owns the APIs and data being loaded
...@@ -39,9 +39,6 @@ class AbstractDataLoader(metaclass=abc.ABCMeta): ...@@ -39,9 +39,6 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
if token_type: if token_type:
token_type = token_type.lower() token_type = token_type.lower()
if token_type not in self.SUPPORTED_TOKEN_TYPES:
raise ValueError('The token type {token_type} is invalid!'.format(token_type=token_type))
self.access_token = access_token self.access_token = access_token
self.token_type = token_type self.token_type = token_type
self.partner = partner self.partner = partner
...@@ -49,6 +46,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta): ...@@ -49,6 +46,7 @@ class AbstractDataLoader(metaclass=abc.ABCMeta):
self.max_workers = max_workers self.max_workers = max_workers
self.is_threadsafe = is_threadsafe self.is_threadsafe = is_threadsafe
self.username = kwargs.get('username')
@cached_property @cached_property
def api_client(self): def api_client(self):
......
...@@ -96,7 +96,7 @@ class CoursesApiDataLoader(AbstractDataLoader): ...@@ -96,7 +96,7 @@ class CoursesApiDataLoader(AbstractDataLoader):
self._process_response(response) self._process_response(response)
def _make_request(self, page): def _make_request(self, page):
return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, username='course_discovery_worker') return self.api_client.courses().get(page=page, page_size=self.PAGE_SIZE, username=self.username)
def _process_response(self, response): def _process_response(self, response):
results = response['results'] results = response['results']
...@@ -294,9 +294,10 @@ class ProgramsApiDataLoader(AbstractDataLoader): ...@@ -294,9 +294,10 @@ class ProgramsApiDataLoader(AbstractDataLoader):
image_height = 480 image_height = 480
XSERIES = None XSERIES = None
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False): def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
super(ProgramsApiDataLoader, self).__init__( super(ProgramsApiDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe partner, api_url, access_token, token_type, max_workers, is_threadsafe, **kwargs
) )
self.XSERIES = ProgramType.objects.get(name='XSeries') self.XSERIES = ProgramType.objects.get(name='XSeries')
......
...@@ -23,9 +23,10 @@ logger = logging.getLogger(__name__) ...@@ -23,9 +23,10 @@ logger = logging.getLogger(__name__)
class AbstractMarketingSiteDataLoader(AbstractDataLoader): class AbstractMarketingSiteDataLoader(AbstractDataLoader):
def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None, is_threadsafe=False): def __init__(self, partner, api_url, access_token=None, token_type=None, max_workers=None,
is_threadsafe=False, **kwargs):
super(AbstractMarketingSiteDataLoader, self).__init__( super(AbstractMarketingSiteDataLoader, self).__init__(
partner, api_url, access_token, token_type, max_workers, is_threadsafe partner, api_url, access_token, token_type, max_workers, is_threadsafe, **kwargs
) )
if not (self.partner.marketing_site_api_username and self.partner.marketing_site_api_password): if not (self.partner.marketing_site_api_username and self.partner.marketing_site_api_password):
......
import ddt
import responses import responses
from edx_rest_api_client.auth import BearerAuth, SuppliedJwtAuth from edx_rest_api_client.auth import SuppliedJwtAuth
from edx_rest_api_client.client import EdxRestApiClient from edx_rest_api_client.client import EdxRestApiClient
from course_discovery.apps.course_metadata.tests.factories import PartnerFactory from course_discovery.apps.course_metadata.tests.factories import PartnerFactory
ACCESS_TOKEN = 'secret' ACCESS_TOKEN = 'secret'
ACCESS_TOKEN_TYPE = 'Bearer' ACCESS_TOKEN_TYPE = 'JWT'
@ddt.ddt
class ApiClientTestMixin(object): class ApiClientTestMixin(object):
@ddt.unpack def test_api_client(self):
@ddt.data(
('Bearer', BearerAuth),
('JWT', SuppliedJwtAuth),
)
def test_api_client(self, token_type, expected_auth_class):
""" Verify the property returns an API client with the correct authentication. """ """ Verify the property returns an API client with the correct authentication. """
loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, token_type) loader = self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, ACCESS_TOKEN_TYPE)
client = loader.api_client client = loader.api_client
self.assertIsInstance(client, EdxRestApiClient) self.assertIsInstance(client, EdxRestApiClient)
# NOTE (CCB): My initial preference was to mock the constructor and ensure the correct auth arguments # NOTE (CCB): My initial preference was to mock the constructor and ensure the correct auth arguments
# were passed. However, that seems nearly impossible. This is the next best alternative. It is brittle, and # were passed. However, that seems nearly impossible. This is the next best alternative. It is brittle, and
# may break if we ever change the underlying request class of EdxRestApiClient. # may break if we ever change the underlying request class of EdxRestApiClient.
self.assertIsInstance(client._store['session'].auth, expected_auth_class) # pylint: disable=protected-access self.assertIsInstance(client._store['session'].auth, SuppliedJwtAuth) # pylint: disable=protected-access
# pylint: disable=not-callable # pylint: disable=not-callable
...@@ -45,15 +38,10 @@ class DataLoaderTestMixin(object): ...@@ -45,15 +38,10 @@ class DataLoaderTestMixin(object):
""" Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """ """ Asserts the API was called with the correct number of calls, and the appropriate Authorization header. """
self.assertEqual(len(responses.calls), expected_num_calls) self.assertEqual(len(responses.calls), expected_num_calls)
if check_auth: if check_auth:
self.assertEqual(responses.calls[0].request.headers['Authorization'], 'Bearer {}'.format(ACCESS_TOKEN)) self.assertEqual(responses.calls[0].request.headers['Authorization'], 'JWT {}'.format(ACCESS_TOKEN))
def test_init(self): def test_init(self):
""" Verify the constructor sets the appropriate attributes. """ """ Verify the constructor sets the appropriate attributes. """
self.assertEqual(self.loader.partner.short_code, self.partner.short_code) self.assertEqual(self.loader.partner.short_code, self.partner.short_code)
self.assertEqual(self.loader.access_token, ACCESS_TOKEN) self.assertEqual(self.loader.access_token, ACCESS_TOKEN)
self.assertEqual(self.loader.token_type, ACCESS_TOKEN_TYPE.lower()) self.assertEqual(self.loader.token_type, ACCESS_TOKEN_TYPE.lower())
def test_init_with_unsupported_token_type(self):
""" Verify the constructor raises an error if an unsupported token type is passed in. """
with self.assertRaises(ValueError):
self.loader_class(self.partner, self.api_url, ACCESS_TOKEN, 'not-supported')
...@@ -5,6 +5,7 @@ import logging ...@@ -5,6 +5,7 @@ import logging
from django.core.management import BaseCommand, CommandError from django.core.management import BaseCommand, CommandError
from django.db import connection from django.db import connection
from edx_rest_api_client.client import EdxRestApiClient from edx_rest_api_client.client import EdxRestApiClient
import jwt
import waffle import waffle
from course_discovery.apps.core.models import Partner from course_discovery.apps.core.models import Partner
...@@ -21,9 +22,9 @@ from course_discovery.apps.course_metadata.models import Course ...@@ -21,9 +22,9 @@ from course_discovery.apps.course_metadata.models import Course
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def execute_loader(loader_class, *loader_args): def execute_loader(loader_class, *loader_args, **loader_kwargs):
try: try:
loader_class(*loader_args).ingest() loader_class(*loader_args, **loader_kwargs).ingest()
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('%s failed!', loader_class.__name__) logger.exception('%s failed!', loader_class.__name__)
...@@ -52,22 +53,6 @@ class Command(BaseCommand): ...@@ -52,22 +53,6 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--access_token',
action='store',
dest='access_token',
default=None,
help='OAuth2 access token used to authenticate API calls.'
)
parser.add_argument(
'--token_type',
action='store',
dest='token_type',
default=None,
help='The type of access token being passed (e.g. Bearer, JWT).'
)
parser.add_argument(
'--partner_code', '--partner_code',
action='store', action='store',
dest='partner_code', dest='partner_code',
...@@ -98,27 +83,22 @@ class Command(BaseCommand): ...@@ -98,27 +83,22 @@ class Command(BaseCommand):
if not partners: if not partners:
raise CommandError('No partners available!') raise CommandError('No partners available!')
token_type = 'JWT'
for partner in partners: for partner in partners:
access_token = options.get('access_token') logger.info('Retrieving access token for partner [{}]'.format(partner_code))
token_type = options.get('token_type')
try:
if access_token and not token_type: access_token, __ = EdxRestApiClient.get_oauth_access_token(
raise CommandError('The token_type must be specified when passing in an access token!') '{root}/access_token'.format(root=partner.oidc_url_root.strip('/')),
partner.oidc_key,
if not access_token: partner.oidc_secret,
logger.info('No access token provided. Retrieving access token using client_credential flow...') token_type=token_type
token_type = 'JWT' )
except Exception:
try: logger.exception('No access token acquired through client_credential flow.')
access_token, __ = EdxRestApiClient.get_oauth_access_token( raise
'{root}/access_token'.format(root=partner.oidc_url_root.strip('/')), username = jwt.decode(access_token, verify=False)['preferred_username']
partner.oidc_key, kwargs = {'username': username} if username else {}
partner.oidc_secret,
token_type=token_type
)
except Exception:
logger.exception('No access token provided or acquired through client_credential flow.')
raise
# The Linux kernel implements copy-on-write when fork() is called to create a new # The Linux kernel implements copy-on-write when fork() is called to create a new
# process. Pages that the parent and child processes share, such as the database # process. Pages that the parent and child processes share, such as the database
...@@ -192,6 +172,7 @@ class Command(BaseCommand): ...@@ -192,6 +172,7 @@ class Command(BaseCommand):
token_type, token_type,
(max_workers_override or max_workers), (max_workers_override or max_workers),
is_threadsafe, is_threadsafe,
**kwargs,
) )
else: else:
# Flatten pipeline and run serially. # Flatten pipeline and run serially.
...@@ -205,6 +186,7 @@ class Command(BaseCommand): ...@@ -205,6 +186,7 @@ class Command(BaseCommand):
token_type, token_type,
(max_workers_override or max_workers), (max_workers_override or max_workers),
is_threadsafe, is_threadsafe,
**kwargs,
) )
# TODO Cleanup CourseRun overrides equivalent to the Course values. # TODO Cleanup CourseRun overrides equivalent to the Course values.
import json import json
import ddt import ddt
import jwt
import mock import mock
import responses import responses
from django.core.management import call_command, CommandError from django.core.management import call_command, CommandError
...@@ -19,8 +20,8 @@ from course_discovery.apps.course_metadata.data_loaders.tests import mock_data ...@@ -19,8 +20,8 @@ from course_discovery.apps.course_metadata.data_loaders.tests import mock_data
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import CourseFactory from course_discovery.apps.course_metadata.tests.factories import CourseFactory
ACCESS_TOKEN = 'secret'
JSON = 'application/json' JSON = 'application/json'
ACCESS_TOKEN = str(jwt.encode({'preferred_username': 'bob'}, 'secret'), 'utf-8')
@ddt.ddt @ddt.ddt
...@@ -134,11 +135,15 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase): ...@@ -134,11 +135,15 @@ class RefreshCourseMetadataCommandTests(TransactionTestCase):
command_args = ['--partner_code=invalid'] command_args = ['--partner_code=invalid']
call_command('refresh_course_metadata', *command_args) call_command('refresh_course_metadata', *command_args)
def test_refresh_course_metadata_with_no_token_type(self): def test_refresh_course_metadata_errors_with_no_token(self):
""" Verify an error is raised if an access token is passed in without a token type. """ """ Verify an exception is raised and an error is logged if an access token is not acquired. """
with self.assertRaises(CommandError): with mock.patch('edx_rest_api_client.client.EdxRestApiClient.get_oauth_access_token', side_effect=Exception):
command_args = ['--access_token=test-access-token'] logger = 'course_discovery.apps.course_metadata.management.commands.refresh_course_metadata.logger'
call_command('refresh_course_metadata', *command_args) with mock.patch(logger) as mock_logger:
with self.assertRaises(Exception):
call_command('refresh_course_metadata')
expected_calls = [mock.call('No access token acquired through client_credential flow.')]
mock_logger.exception.assert_has_calls(expected_calls)
def test_refresh_course_metadata_with_loader_exception(self): def test_refresh_course_metadata_with_loader_exception(self):
""" Verify execution continues if an individual data loader fails. """ """ Verify execution continues if an individual data loader fails. """
......
...@@ -21,12 +21,11 @@ Retrieving Course Metadata ...@@ -21,12 +21,11 @@ Retrieving Course Metadata
The ``refresh_course_metadata`` command in :file:`course_discovery/apps/course_metadata/management/commands/refresh_course_metadata.py` is used to retrieve metadata. This is run daily in production through a Jenkins job, and can be manually run to set up your local environment. The data loaders are each run in series by the command. The data The ``refresh_course_metadata`` command in :file:`course_discovery/apps/course_metadata/management/commands/refresh_course_metadata.py` is used to retrieve metadata. This is run daily in production through a Jenkins job, and can be manually run to set up your local environment. The data loaders are each run in series by the command. The data
loaders should be idempotent -- that is, running this command once will populate the database, and if nothing has loaders should be idempotent -- that is, running this command once will populate the database, and if nothing has
changed upstream then running it again should not change the database. changed upstream then running it again should not change the database.
The command retrieves a JWT access token from the configured OpenID Connect provider.
For example, if you use the JWT access token "secret-ecommerce-key" to authenticate API calls, run the following:
.. code-block:: bash .. code-block:: bash
$ ./manage.py refresh_course_metadata --access_token secret-ecommerce-key --token_type jwt $ ./manage.py refresh_course_metadata
QuerySets QuerySets
--------- ---------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment