Commit 71849f13 by Clinton Blackburn Committed by Clinton Blackburn

Exposed update_index command via Management API

ECOM-4747
parent f099c72b
...@@ -5,13 +5,13 @@ from rest_framework.test import APITestCase ...@@ -5,13 +5,13 @@ from rest_framework.test import APITestCase
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
class RefreshCourseMetadataTests(APITestCase): class ManagementCommandViewTestMixin(object):
""" Tests for the refresh_course_metadata management endpoint. """ call_command_path = None
path = reverse('api:v1:management-refresh-course-metadata') command_name = None
call_command_path = 'course_discovery.apps.api.v1.views.call_command' path = None
def setUp(self): def setUp(self):
super(RefreshCourseMetadataTests, self).setUp() super(ManagementCommandViewTestMixin, self).setUp()
self.superuser = UserFactory(is_superuser=True) self.superuser = UserFactory(is_superuser=True)
self.client.force_authenticate(self.superuser) # pylint: disable=no-member self.client.force_authenticate(self.superuser) # pylint: disable=no-member
...@@ -20,12 +20,8 @@ class RefreshCourseMetadataTests(APITestCase): ...@@ -20,12 +20,8 @@ class RefreshCourseMetadataTests(APITestCase):
response = self.client.post(self.path) response = self.client.post(self.path)
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_superuser_required(self): def test_non_superusers_denied(self):
""" Verify only superusers can access the endpoint. """ """ Verify access is denied to non-superusers. """
with mock.patch(self.call_command_path, return_value=None):
response = self.client.post(self.path)
self.assertEqual(response.status_code, 200)
# Anonymous user # Anonymous user
self.client.logout() self.client.logout()
self.assert_access_forbidden() self.assert_access_forbidden()
...@@ -42,8 +38,8 @@ class RefreshCourseMetadataTests(APITestCase): ...@@ -42,8 +38,8 @@ class RefreshCourseMetadataTests(APITestCase):
self.assert_successful_response('abc123') self.assert_successful_response('abc123')
def assert_successful_response(self, access_token=None): def assert_successful_response(self, access_token=None):
""" Asserts the endpoint called the refresh_course_metadata management command with the correct arguments, """ Asserts the endpoint called the correct management command with the correct arguments, and the endpoint
and the endpoint returns HTTP 200 with text/plain content type. """ returns HTTP 200 with text/plain content type. """
data = {'access_token': access_token} if access_token else None data = {'access_token': access_token} if access_token else None
with mock.patch(self.call_command_path, return_value=None) as mocked_call_command: with mock.patch(self.call_command_path, return_value=None) as mocked_call_command:
response = self.client.post(self.path, data) response = self.client.post(self.path, data)
...@@ -55,9 +51,26 @@ class RefreshCourseMetadataTests(APITestCase): ...@@ -55,9 +51,26 @@ class RefreshCourseMetadataTests(APITestCase):
expected = { expected = {
'settings': 'course_discovery.settings.test' 'settings': 'course_discovery.settings.test'
} }
if access_token:
expected['access_token'] = access_token
self.assertTrue(mocked_call_command.called) self.assertTrue(mocked_call_command.called)
self.assertEqual(args[0], 'refresh_course_metadata') self.assertEqual(args[0], self.command_name)
self.assertDictContainsSubset(expected, kwargs) self.assertDictContainsSubset(expected, kwargs)
class RefreshCourseMetadataTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the refresh_course_metadata management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command'
command_name = 'refresh_course_metadata'
path = reverse('api:v1:management-refresh-course-metadata')
def test_success_response(self):
""" Verify a successful response calls the management command and returns the plain text output. """
super(RefreshCourseMetadataTests, self).test_success_response()
self.assert_successful_response(access_token='abc123')
class UpdateIndexTests(ManagementCommandViewTestMixin, APITestCase):
""" Tests for the update_index management endpoint. """
call_command_path = 'course_discovery.apps.api.v1.views.call_command'
command_name = 'update_index'
path = reverse('api:v1:management-update-index')
...@@ -293,7 +293,23 @@ class ManagementViewSet(viewsets.ViewSet): ...@@ -293,7 +293,23 @@ class ManagementViewSet(viewsets.ViewSet):
multiple: false multiple: false
""" """
access_token = request.data.get('access_token') access_token = request.data.get('access_token')
kwargs = {'access_token': access_token} if access_token else {}
name = 'refresh_course_metadata'
output = self.run_command(request, name, **kwargs)
return Response(output, content_type='text/plain')
@list_route(methods=['post'])
def update_index(self, request):
""" Update the search index. """
name = 'update_index'
output = self.run_command(request, name)
return Response(output, content_type='text/plain')
def run_command(self, request, name, **kwargs):
# Capture all output and logging # Capture all output and logging
out = StringIO() out = StringIO()
err = StringIO() err = StringIO()
...@@ -305,18 +321,13 @@ class ManagementViewSet(viewsets.ViewSet): ...@@ -305,18 +321,13 @@ class ManagementViewSet(viewsets.ViewSet):
log_handler.setFormatter(formatter) log_handler.setFormatter(formatter)
root_logger.addHandler(log_handler) root_logger.addHandler(log_handler)
logger.info('Updating course metadata per request of [%s]...', request.user.username) logger.info('Running [%s] per request of [%s]...', name, request.user.username)
call_command(name, settings=os.environ['DJANGO_SETTINGS_MODULE'], stdout=out, stderr=err, **kwargs)
kwargs = {'access_token': access_token} if access_token else {}
call_command('refresh_course_metadata', settings=os.environ['DJANGO_SETTINGS_MODULE'], stdout=out, stderr=err,
**kwargs)
# Format the output for display # Format the output for display
output = 'STDOUT\n{out}\n\nSTDERR\n{err}\n\nLOG\n{log}'.format(out=out.getvalue(), err=err.getvalue(), output = 'STDOUT\n{out}\n\nSTDERR\n{err}\n\nLOG\n{log}'.format(out=out.getvalue(), err=err.getvalue(),
log=log.getvalue()) log=log.getvalue())
return output
return Response(output, content_type='text/plain')
class AffiliateWindowViewSet(viewsets.ViewSet): class AffiliateWindowViewSet(viewsets.ViewSet):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment