Commit 95e8d8a6 by Douglas Hall Committed by Douglas Hall

Refactor LMSAPIClient to use discovery service user instead of request user.

ENT-657
parent 42b3ed3c
......@@ -159,7 +159,7 @@ def get_utm_source_for_user(partner, user):
# If use_company_name_as_utm_source_value is enabled and lms_url value is set then
# use company name from API Access Request as utm_source.
if waffle.switch_is_active('use_company_name_as_utm_source_value') and partner.lms_url:
lms = LMSAPIClient(partner.site, user)
lms = LMSAPIClient(partner.site)
api_access_request = lms.get_api_access_request(user)
if api_access_request:
......
......@@ -4,6 +4,7 @@ import itertools
from urllib.parse import urlencode
import ddt
import mock
import pytest
import pytz
import responses
......@@ -33,7 +34,7 @@ from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, Ca
VideoSerializer, get_utm_source_for_user)
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User
from course_discovery.apps.core.models import Partner, User
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin, LMSAPIClientMixin
......@@ -1409,7 +1410,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
self.partner = PartnerFactory.create()
@override_switch('use_company_name_as_utm_source_value', active=False)
def test_with_waffle_switch_turned_off(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_with_waffle_switch_turned_off(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_utm_source_for_user` returns User's username when waffle switch
`use_company_name_as_utm_source_value` is turned off.
......@@ -1417,7 +1419,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
def test_with_missing_lms_url(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_with_missing_lms_url(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_utm_source_for_user` returns default value if
`Partner.lms_url` is not set in the database.
......@@ -1429,16 +1432,18 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_when_api_response_is_not_valid(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_when_api_response_is_not_valid(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_utm_source_for_user` returns default value if
LMS API does not return a valid response.
"""
self.mock_api_access_request(self.partner.lms_url, status=400)
self.mock_api_access_request(self.partner.lms_url, self.user, status=400)
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_get_utm_source_for_user(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_utm_source_for_user(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_utm_source_for_user` returns correct value.
"""
......@@ -1446,6 +1451,6 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
expected_utm_source = slugify('{} {}'.format(self.user.username, company_name))
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides={'company_name': company_name},
self.partner.lms_url, self.user, api_access_request_overrides={'company_name': company_name},
)
assert get_utm_source_for_user(self.partner, self.user) == expected_utm_source
......@@ -18,10 +18,8 @@ class LMSAPIClient(object):
API Client for communication between discovery and LMS.
"""
def __init__(self, site, user):
self.site = site
self.user = user
self.client = EdxRestApiClient(self.site.partner.lms_url, oauth_access_token=user.access_token)
def __init__(self, site):
self.client = EdxRestApiClient(site.partner.lms_url, jwt=site.partner.access_token)
def get_api_access_request(self, user):
"""
......@@ -51,14 +49,9 @@ class LMSAPIClient(object):
}
"""
resource = 'api-admin/api/v1/api_access_request/'
query_parameters = {}
# Since Staff user has access to all API Access Requests and we want to limit the response.
# So, we will filter by username for staff users.
if user.is_staff:
query_parameters = {
'user__username': user.username
}
query_parameters = {
'user__username': user.username
}
cache_key = get_cache_key(username=user.username, resource=resource)
api_access_request = cache.get(cache_key)
......@@ -78,7 +71,6 @@ class LMSAPIClient(object):
except (SlumberBaseException, ConnectionError, Timeout):
logger.exception('Failed to fetch API Access Request from LMS for user "%s".', user.username)
except (IndexError, KeyError):
# This should not happen as user must always have at-least one api-access-request.
logger.exception('APIAccessRequest model not found for user [%s].', user.username)
logger.info('APIAccessRequest model not found for user [%s].', user.username)
return api_access_request
......@@ -42,7 +42,7 @@ class ElasticsearchTestMixin(object):
class LMSAPIClientMixin(object):
def mock_api_access_request(self, lms_url, status=200, api_access_request_overrides=None):
def mock_api_access_request(self, lms_url, user, status=200, api_access_request_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
......@@ -76,13 +76,13 @@ class LMSAPIClientMixin(object):
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/?user__username={}'.format(user.username),
body=json.dumps(data),
content_type='application/json',
status=status
)
def mock_api_access_request_with_invalid_data(self, lms_url, status=200, response_overrides=None):
def mock_api_access_request_with_invalid_data(self, lms_url, user, status=200, response_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
......@@ -90,7 +90,7 @@ class LMSAPIClientMixin(object):
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/?user__username={}'.format(user.username),
body=json.dumps(data),
content_type='application/json',
status=status
......
import logging
import mock
import responses
from django.test import TestCase
from course_discovery.apps.core.api_client import lms
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import LMSAPIClientMixin
from course_discovery.apps.core.tests.utils import MockLoggingHandler
......@@ -19,7 +21,8 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
logger.addHandler(cls.log_handler)
cls.log_messages = cls.log_handler.messages
def setUp(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def setUp(self, mock_access_token): # pylint: disable=unused-argument
super(TestLMSAPIClient, self).setUp()
# Reset mock logger for each test.
......@@ -27,7 +30,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
self.user = UserFactory.create()
self.partner = PartnerFactory.create()
self.lms = lms.LMSAPIClient(self.partner.site, self.user)
self.lms = lms.LMSAPIClient(self.partner.site)
self.response = {
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
......@@ -43,43 +46,47 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
}
@responses.activate
def test_get_api_access_request(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_api_access_request` returns correct value.
"""
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides=self.response
self.partner.lms_url, self.user, api_access_request_overrides=self.response
)
assert self.lms.get_api_access_request(self.user) == self.response
@responses.activate
def test_get_api_access_request_with_404_error(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_404_error(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request(
self.partner.lms_url, status=404
self.partner.lms_url, self.user, status=404
)
assert self.lms.get_api_access_request(self.user) is None
assert 'Failed to fetch API Access Request from LMS for user "%s".' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_empty_response(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_empty_response(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url
self.partner.lms_url, self.user
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
self.log_messages['info']
@responses.activate
def test_get_api_access_request_with_invalid_response(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_invalid_response(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
......@@ -101,14 +108,15 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_invalid_response
self.partner.lms_url, self.user, response_overrides=sample_invalid_response
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
self.log_messages['info']
@responses.activate
def test_get_api_access_request_with_multiple_records(self):
@mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_multiple_records(self, mock_access_token): # pylint: disable=unused-argument
"""
Verify that `get_api_access_request` logs a warning message and returns the first result
if endpoint returns multiple api-access-requests for a user.
......@@ -154,7 +162,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_response_with_multiple_users
self.partner.lms_url, self.user, response_overrides=sample_response_with_multiple_users
)
assert self.lms.get_api_access_request(self.user)['company_name'] == 'Test Company'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment