Commit 95e8d8a6 by Douglas Hall Committed by Douglas Hall

Refactor LMSAPIClient to use discovery service user instead of request user.

ENT-657
parent 42b3ed3c
...@@ -159,7 +159,7 @@ def get_utm_source_for_user(partner, user): ...@@ -159,7 +159,7 @@ def get_utm_source_for_user(partner, user):
# If use_company_name_as_utm_source_value is enabled and lms_url value is set then # If use_company_name_as_utm_source_value is enabled and lms_url value is set then
# use company name from API Access Request as utm_source. # use company name from API Access Request as utm_source.
if waffle.switch_is_active('use_company_name_as_utm_source_value') and partner.lms_url: if waffle.switch_is_active('use_company_name_as_utm_source_value') and partner.lms_url:
lms = LMSAPIClient(partner.site, user) lms = LMSAPIClient(partner.site)
api_access_request = lms.get_api_access_request(user) api_access_request = lms.get_api_access_request(user)
if api_access_request: if api_access_request:
......
...@@ -4,6 +4,7 @@ import itertools ...@@ -4,6 +4,7 @@ import itertools
from urllib.parse import urlencode from urllib.parse import urlencode
import ddt import ddt
import mock
import pytest import pytest
import pytz import pytz
import responses import responses
...@@ -33,7 +34,7 @@ from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, Ca ...@@ -33,7 +34,7 @@ from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, Ca
VideoSerializer, get_utm_source_for_user) VideoSerializer, get_utm_source_for_user)
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import Partner, User
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin, LMSAPIClientMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin, LMSAPIClientMixin
...@@ -1409,7 +1410,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase): ...@@ -1409,7 +1410,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
self.partner = PartnerFactory.create() self.partner = PartnerFactory.create()
@override_switch('use_company_name_as_utm_source_value', active=False) @override_switch('use_company_name_as_utm_source_value', active=False)
def test_with_waffle_switch_turned_off(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_with_waffle_switch_turned_off(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_utm_source_for_user` returns User's username when waffle switch Verify that `get_utm_source_for_user` returns User's username when waffle switch
`use_company_name_as_utm_source_value` is turned off. `use_company_name_as_utm_source_value` is turned off.
...@@ -1417,7 +1419,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase): ...@@ -1417,7 +1419,8 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
assert get_utm_source_for_user(self.partner, self.user) == self.user.username assert get_utm_source_for_user(self.partner, self.user) == self.user.username
def test_with_missing_lms_url(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_with_missing_lms_url(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_utm_source_for_user` returns default value if Verify that `get_utm_source_for_user` returns default value if
`Partner.lms_url` is not set in the database. `Partner.lms_url` is not set in the database.
...@@ -1429,16 +1432,18 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase): ...@@ -1429,16 +1432,18 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
assert get_utm_source_for_user(self.partner, self.user) == self.user.username assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate @responses.activate
def test_when_api_response_is_not_valid(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_when_api_response_is_not_valid(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_utm_source_for_user` returns default value if Verify that `get_utm_source_for_user` returns default value if
LMS API does not return a valid response. LMS API does not return a valid response.
""" """
self.mock_api_access_request(self.partner.lms_url, status=400) self.mock_api_access_request(self.partner.lms_url, self.user, status=400)
assert get_utm_source_for_user(self.partner, self.user) == self.user.username assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate @responses.activate
def test_get_utm_source_for_user(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_utm_source_for_user(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_utm_source_for_user` returns correct value. Verify that `get_utm_source_for_user` returns correct value.
""" """
...@@ -1446,6 +1451,6 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase): ...@@ -1446,6 +1451,6 @@ class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
expected_utm_source = slugify('{} {}'.format(self.user.username, company_name)) expected_utm_source = slugify('{} {}'.format(self.user.username, company_name))
self.mock_api_access_request( self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides={'company_name': company_name}, self.partner.lms_url, self.user, api_access_request_overrides={'company_name': company_name},
) )
assert get_utm_source_for_user(self.partner, self.user) == expected_utm_source assert get_utm_source_for_user(self.partner, self.user) == expected_utm_source
...@@ -18,10 +18,8 @@ class LMSAPIClient(object): ...@@ -18,10 +18,8 @@ class LMSAPIClient(object):
API Client for communication between discovery and LMS. API Client for communication between discovery and LMS.
""" """
def __init__(self, site, user): def __init__(self, site):
self.site = site self.client = EdxRestApiClient(site.partner.lms_url, jwt=site.partner.access_token)
self.user = user
self.client = EdxRestApiClient(self.site.partner.lms_url, oauth_access_token=user.access_token)
def get_api_access_request(self, user): def get_api_access_request(self, user):
""" """
...@@ -51,14 +49,9 @@ class LMSAPIClient(object): ...@@ -51,14 +49,9 @@ class LMSAPIClient(object):
} }
""" """
resource = 'api-admin/api/v1/api_access_request/' resource = 'api-admin/api/v1/api_access_request/'
query_parameters = {} query_parameters = {
'user__username': user.username
# Since Staff user has access to all API Access Requests and we want to limit the response. }
# So, we will filter by username for staff users.
if user.is_staff:
query_parameters = {
'user__username': user.username
}
cache_key = get_cache_key(username=user.username, resource=resource) cache_key = get_cache_key(username=user.username, resource=resource)
api_access_request = cache.get(cache_key) api_access_request = cache.get(cache_key)
...@@ -78,7 +71,6 @@ class LMSAPIClient(object): ...@@ -78,7 +71,6 @@ class LMSAPIClient(object):
except (SlumberBaseException, ConnectionError, Timeout): except (SlumberBaseException, ConnectionError, Timeout):
logger.exception('Failed to fetch API Access Request from LMS for user "%s".', user.username) logger.exception('Failed to fetch API Access Request from LMS for user "%s".', user.username)
except (IndexError, KeyError): except (IndexError, KeyError):
# This should not happen as user must always have at-least one api-access-request. logger.info('APIAccessRequest model not found for user [%s].', user.username)
logger.exception('APIAccessRequest model not found for user [%s].', user.username)
return api_access_request return api_access_request
...@@ -42,7 +42,7 @@ class ElasticsearchTestMixin(object): ...@@ -42,7 +42,7 @@ class ElasticsearchTestMixin(object):
class LMSAPIClientMixin(object): class LMSAPIClientMixin(object):
def mock_api_access_request(self, lms_url, status=200, api_access_request_overrides=None): def mock_api_access_request(self, lms_url, user, status=200, api_access_request_overrides=None):
""" """
Mock the api access requests endpoint response of the LMS. Mock the api access requests endpoint response of the LMS.
""" """
...@@ -76,13 +76,13 @@ class LMSAPIClientMixin(object): ...@@ -76,13 +76,13 @@ class LMSAPIClientMixin(object):
responses.add( responses.add(
responses.GET, responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/', lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/?user__username={}'.format(user.username),
body=json.dumps(data), body=json.dumps(data),
content_type='application/json', content_type='application/json',
status=status status=status
) )
def mock_api_access_request_with_invalid_data(self, lms_url, status=200, response_overrides=None): def mock_api_access_request_with_invalid_data(self, lms_url, user, status=200, response_overrides=None):
""" """
Mock the api access requests endpoint response of the LMS. Mock the api access requests endpoint response of the LMS.
""" """
...@@ -90,7 +90,7 @@ class LMSAPIClientMixin(object): ...@@ -90,7 +90,7 @@ class LMSAPIClientMixin(object):
responses.add( responses.add(
responses.GET, responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/', lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/?user__username={}'.format(user.username),
body=json.dumps(data), body=json.dumps(data),
content_type='application/json', content_type='application/json',
status=status status=status
......
import logging import logging
import mock
import responses import responses
from django.test import TestCase from django.test import TestCase
from course_discovery.apps.core.api_client import lms from course_discovery.apps.core.api_client import lms
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import LMSAPIClientMixin from course_discovery.apps.core.tests.mixins import LMSAPIClientMixin
from course_discovery.apps.core.tests.utils import MockLoggingHandler from course_discovery.apps.core.tests.utils import MockLoggingHandler
...@@ -19,7 +21,8 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase): ...@@ -19,7 +21,8 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
logger.addHandler(cls.log_handler) logger.addHandler(cls.log_handler)
cls.log_messages = cls.log_handler.messages cls.log_messages = cls.log_handler.messages
def setUp(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def setUp(self, mock_access_token): # pylint: disable=unused-argument
super(TestLMSAPIClient, self).setUp() super(TestLMSAPIClient, self).setUp()
# Reset mock logger for each test. # Reset mock logger for each test.
...@@ -27,7 +30,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase): ...@@ -27,7 +30,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
self.user = UserFactory.create() self.user = UserFactory.create()
self.partner = PartnerFactory.create() self.partner = PartnerFactory.create()
self.lms = lms.LMSAPIClient(self.partner.site, self.user) self.lms = lms.LMSAPIClient(self.partner.site)
self.response = { self.response = {
'id': 1, 'id': 1,
'created': '2017-09-25T08:37:05.872566Z', 'created': '2017-09-25T08:37:05.872566Z',
...@@ -43,43 +46,47 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase): ...@@ -43,43 +46,47 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
} }
@responses.activate @responses.activate
def test_get_api_access_request(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_api_access_request` returns correct value. Verify that `get_api_access_request` returns correct value.
""" """
self.mock_api_access_request( self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides=self.response self.partner.lms_url, self.user, api_access_request_overrides=self.response
) )
assert self.lms.get_api_access_request(self.user) == self.response assert self.lms.get_api_access_request(self.user) == self.response
@responses.activate @responses.activate
def test_get_api_access_request_with_404_error(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_404_error(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_api_access_request` returns None when api_access_request Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available. API endpoint is not available.
""" """
self.mock_api_access_request( self.mock_api_access_request(
self.partner.lms_url, status=404 self.partner.lms_url, self.user, status=404
) )
assert self.lms.get_api_access_request(self.user) is None assert self.lms.get_api_access_request(self.user) is None
assert 'Failed to fetch API Access Request from LMS for user "%s".' % self.user.username in \ assert 'Failed to fetch API Access Request from LMS for user "%s".' % self.user.username in \
self.log_messages['error'] self.log_messages['error']
@responses.activate @responses.activate
def test_get_api_access_request_with_empty_response(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_empty_response(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_api_access_request` returns None when api_access_request Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available. API endpoint is not available.
""" """
self.mock_api_access_request_with_invalid_data( self.mock_api_access_request_with_invalid_data(
self.partner.lms_url self.partner.lms_url, self.user
) )
assert self.lms.get_api_access_request(self.user) is None assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \ assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error'] self.log_messages['info']
@responses.activate @responses.activate
def test_get_api_access_request_with_invalid_response(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_invalid_response(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_api_access_request` returns None when api_access_request Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available. API endpoint is not available.
...@@ -101,14 +108,15 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase): ...@@ -101,14 +108,15 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
} }
self.mock_api_access_request_with_invalid_data( self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_invalid_response self.partner.lms_url, self.user, response_overrides=sample_invalid_response
) )
assert self.lms.get_api_access_request(self.user) is None assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \ assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error'] self.log_messages['info']
@responses.activate @responses.activate
def test_get_api_access_request_with_multiple_records(self): @mock.patch.object(Partner, 'access_token', return_value='JWT fake')
def test_get_api_access_request_with_multiple_records(self, mock_access_token): # pylint: disable=unused-argument
""" """
Verify that `get_api_access_request` logs a warning message and returns the first result Verify that `get_api_access_request` logs a warning message and returns the first result
if endpoint returns multiple api-access-requests for a user. if endpoint returns multiple api-access-requests for a user.
...@@ -154,7 +162,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase): ...@@ -154,7 +162,7 @@ class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
} }
self.mock_api_access_request_with_invalid_data( self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_response_with_multiple_users self.partner.lms_url, self.user, response_overrides=sample_response_with_multiple_users
) )
assert self.lms.get_api_access_request(self.user)['company_name'] == 'Test Company' assert self.lms.get_api_access_request(self.user)['company_name'] == 'Test Company'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment