Commit 295fa22d by Saleem Latif Committed by Douglas Hall

Change UTM Source value in API results returned by discovery API

parent ad1dca56
......@@ -4,9 +4,11 @@ import json
from urllib.parse import urlencode
import pytz
import waffle
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db.models.query import Prefetch
from django.utils.text import slugify
from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer
from rest_framework import serializers
......@@ -15,6 +17,7 @@ from taggit_serializer.serializers import TaggitSerializer, TagListSerializerFie
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.api_client.lms import LMSAPIClient
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import (FAQ, CorporateEndorsement, Course, CourseRun, Endorsement,
Image, Organization, Person, PersonSocialNetwork, PersonWork,
......@@ -114,11 +117,12 @@ SELECT_RELATED_FIELDS = {
}
def get_marketing_url_for_user(user, marketing_url, exclude_utm=False):
def get_marketing_url_for_user(partner, user, marketing_url, exclude_utm=False):
"""
Return the given marketing URL with affiliate query parameters for the user.
Arguments:
partner (Partner): Partner instance containing information.
user (User): Used to construct UTM query parameters.
marketing_url (str | None): Base URL to which UTM parameters may be appended.
......@@ -134,12 +138,36 @@ def get_marketing_url_for_user(user, marketing_url, exclude_utm=False):
return marketing_url
else:
params = urlencode({
'utm_source': user.username,
'utm_source': get_utm_source_for_user(partner, user),
'utm_medium': user.referral_tracking_id,
})
return '{url}?{params}'.format(url=marketing_url, params=params)
def get_utm_source_for_user(partner, user):
"""
Return the utm source for the user.
Arguments:
partner (Partner): Partner instance containing information.
user (User): Used to construct UTM query parameters.
Returns:
str: username and company name slugified and combined together.
"""
utm_source = user.username
# If use_company_name_as_utm_source_value is enabled and lms_url value is set then
# use company name from API Access Request as utm_source.
if waffle.switch_is_active('use_company_name_as_utm_source_value') and partner.lms_url:
lms = LMSAPIClient(partner.site, user)
api_access_request = lms.get_api_access_request(user)
if api_access_request:
utm_source = '{} {}'.format(utm_source, api_access_request['company_name'])
return slugify(utm_source)
class TimestampModelSerializer(serializers.ModelSerializer):
"""Serializer for timestamped models."""
modified = serializers.DateTimeField()
......@@ -444,6 +472,7 @@ class MinimalCourseRunSerializer(TimestampModelSerializer):
def get_marketing_url(self, obj):
return get_marketing_url_for_user(
obj.course.partner,
self.context['request'].user,
obj.marketing_url,
exclude_utm=self.context.get('exclude_utm')
......@@ -585,6 +614,7 @@ class CourseSerializer(MinimalCourseSerializer):
def get_marketing_url(self, obj):
return get_marketing_url_for_user(
obj.partner,
self.context['request'].user,
obj.marketing_url,
exclude_utm=self.context.get('exclude_utm')
......
......@@ -6,35 +6,46 @@ from urllib.parse import urlencode
import ddt
import pytest
import pytz
import responses
from django.test import TestCase
from django.utils.text import slugify
from haystack.query import SearchQuerySet
from opaque_keys.edx.keys import CourseKey
from rest_framework.test import APIRequestFactory
from waffle.models import Switch
from waffle.testutils import override_switch
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.serializers import (
AffiliateWindowSerializer, CatalogSerializer, ContainedCourseRunsSerializer, ContainedCoursesSerializer,
CorporateEndorsementSerializer, CourseRunSearchSerializer, CourseRunSerializer, CourseRunWithProgramsSerializer,
CourseSearchSerializer, CourseSerializer, CourseWithProgramsSerializer, EndorsementSerializer, FAQSerializer,
FlattenedCourseRunWithCourseSerializer, ImageSerializer, MinimalCourseRunSerializer, MinimalCourseSerializer,
MinimalOrganizationSerializer, MinimalProgramCourseSerializer, MinimalProgramSerializer, NestedProgramSerializer,
OrganizationSerializer, PersonSerializer, PositionSerializer, PrerequisiteSerializer, ProgramSearchSerializer,
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer
)
from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, CatalogSerializer,
ContainedCourseRunsSerializer, ContainedCoursesSerializer,
CorporateEndorsementSerializer, CourseRunSearchSerializer,
CourseRunSerializer, CourseRunWithProgramsSerializer,
CourseSearchSerializer, CourseSerializer,
CourseWithProgramsSerializer, EndorsementSerializer, FAQSerializer,
FlattenedCourseRunWithCourseSerializer, ImageSerializer,
MinimalCourseRunSerializer, MinimalCourseSerializer,
MinimalOrganizationSerializer, MinimalProgramCourseSerializer,
MinimalProgramSerializer, NestedProgramSerializer,
OrganizationSerializer, PersonSerializer, PositionSerializer,
PrerequisiteSerializer, ProgramSearchSerializer, ProgramSerializer,
ProgramTypeSerializer, SeatSerializer, SubjectSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer,
VideoSerializer, get_utm_source_for_user)
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin, LMSAPIClientMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
from course_discovery.apps.course_metadata.tests.factories import (
CorporateEndorsementFactory, CourseFactory, CourseRunFactory, EndorsementFactory, ExpectedLearningItemFactory,
ImageFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, PositionFactory, PrerequisiteFactory,
ProgramFactory, ProgramTypeFactory, SeatFactory, SeatTypeFactory, SubjectFactory, VideoFactory
)
from course_discovery.apps.course_metadata.tests.factories import (CorporateEndorsementFactory, CourseFactory,
CourseRunFactory, EndorsementFactory,
ExpectedLearningItemFactory, ImageFactory,
JobOutlookItemFactory, OrganizationFactory,
PersonFactory, PositionFactory, PrerequisiteFactory,
ProgramFactory, ProgramTypeFactory, SeatFactory,
SeatTypeFactory, SubjectFactory, VideoFactory)
from course_discovery.apps.ietf_language_tags.models import LanguageTag
......@@ -1384,3 +1395,57 @@ class TestTypeaheadProgramSearchSerializer:
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = TypeaheadProgramSearchSerializer(result)
return serializer
class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
def setUp(self):
super(TestGetUTMSourceForUser, self).setUp()
self.switch, __ = Switch.objects.update_or_create(
name='use_company_name_as_utm_source_value', defaults={'active': True}
)
self.user = UserFactory.create()
self.partner = PartnerFactory.create()
@override_switch('use_company_name_as_utm_source_value', active=False)
def test_with_waffle_switch_turned_off(self):
"""
Verify that `get_utm_source_for_user` returns User's username when waffle switch
`use_company_name_as_utm_source_value` is turned off.
"""
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
def test_with_missing_lms_url(self):
"""
Verify that `get_utm_source_for_user` returns default value if
`Partner.lms_url` is not set in the database.
"""
# Remove lms_url from partner.
self.partner.lms_url = ''
self.partner.save()
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_when_api_response_is_not_valid(self):
"""
Verify that `get_utm_source_for_user` returns default value if
LMS API does not return a valid response.
"""
self.mock_api_access_request(self.partner.lms_url, status=400)
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_get_utm_source_for_user(self):
"""
Verify that `get_utm_source_for_user` returns correct value.
"""
company_name = 'Test Company'
expected_utm_source = slugify('{} {}'.format(self.user.username, company_name))
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides={'company_name': company_name},
)
assert get_utm_source_for_user(self.partner, self.user) == expected_utm_source
import hashlib
import logging
import six
logger = logging.getLogger(__name__)
......@@ -38,3 +40,27 @@ def get_query_param(request, name):
return
return cast2int(request.query_params.get(name), name)
def get_cache_key(**kwargs):
"""
Get MD5 encoded cache key for given arguments.
Here is the format of key before MD5 encryption.
key1:value1__key2:value2 ...
Example:
>>> get_cache_key(site_domain="example.com", resource="catalogs")
# Here is key format for above call
# "site_domain:example.com__resource:catalogs"
a54349175618ff1659dee0978e3149ca
Arguments:
**kwargs: Key word arguments that need to be present in cache key.
Returns:
An MD5 encoded key uniquely identified by the key word arguments.
"""
key = '__'.join(['{}:{}'.format(item, value) for item, value in six.iteritems(kwargs)])
return hashlib.md5(key.encode('utf-8')).hexdigest()
......@@ -7,6 +7,7 @@ import pytest
import pytz
import responses
from django.contrib.auth import get_user_model
from django.core.cache import cache
from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
......@@ -17,9 +18,12 @@ from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import Course
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, SeatFactory
from course_discovery.conftest import get_course_run_states
User = get_user_model()
STATES, AVAILABLE_STATES = get_course_run_states()
@ddt.ddt
@pytest.mark.usefixtures('course_run_states')
......@@ -42,6 +46,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
)
self.course = self.course_run.course
self.refresh_index()
cache.clear()
def assert_catalog_created(self, **headers):
name = 'The Kitchen Sink'
......@@ -148,7 +153,10 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
self.assertEqual(response.status_code, 400)
self.assertEqual(User.objects.count(), original_user_count)
def test_courses(self):
@ddt.data(
*STATES()
)
def test_courses(self, state):
"""
Verify the endpoint returns the list of available courses contained in
the catalog, and that courses appearing in the response always have at
......@@ -156,7 +164,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
"""
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
for state in self.states():
Course.objects.all().delete()
course_run = CourseRunFactory(course__title='ABC Test Course')
......@@ -165,14 +172,14 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
course_run.save()
if state in self.available_states:
if state in AVAILABLE_STATES:
course = course_run.course
# This run has no seats, but we still expect its parent course
# to be included.
filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(18):
with self.assertNumQueries(20):
response = self.client.get(url)
assert response.status_code == 200
......@@ -217,7 +224,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(18):
with self.assertNumQueries(20):
response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
......@@ -2,6 +2,8 @@ import datetime
import ddt
import pytz
from django.core.cache import cache
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
......@@ -24,12 +26,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory(partner=self.partner)
cache.clear()
def test_get(self):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(20):
with self.assertNumQueries(21):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -38,7 +41,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(13):
with self.assertNumQueries(14):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), [])
......@@ -51,7 +54,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1'
with self.assertNumQueries(24):
with self.assertNumQueries(25):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
......@@ -187,7 +190,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(26):
with self.assertNumQueries(27):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -203,7 +206,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(39):
with self.assertNumQueries(41):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......@@ -214,7 +217,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(39):
with self.assertNumQueries(40):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import urllib.parse
import pytest
from django.core.cache import cache
from django.test import RequestFactory
from django.urls import reverse
......@@ -41,6 +42,7 @@ class TestProgramViewSet(SerializationMixin):
self.django_assert_num_queries = django_assert_num_queries
self.partner = partner
self.request = request
cache.clear()
def create_program(self):
organizations = [OrganizationFactory(partner=self.partner)]
......@@ -86,7 +88,7 @@ class TestProgramViewSet(SerializationMixin):
def test_retrieve(self, django_assert_num_queries):
""" Verify the endpoint returns the details for a single program. """
program = self.create_program()
with django_assert_num_queries(39):
with django_assert_num_queries(40):
response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
......@@ -112,7 +114,7 @@ class TestProgramViewSet(SerializationMixin):
partner=self.partner)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
with django_assert_num_queries(28):
with django_assert_num_queries(29):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
assert course_list == list(program.courses.all()) # pylint: disable=no-member
......@@ -148,7 +150,7 @@ class TestProgramViewSet(SerializationMixin):
""" Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)]
expected.reverse()
self.assert_list_results(self.list_path, expected, 14)
self.assert_list_results(self.list_path, expected, 15)
# Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 4)
......@@ -272,7 +274,7 @@ class TestProgramViewSet(SerializationMixin):
program.marketing_slug = SLUG
program.save()
self.assert_list_results(url, [program], 14)
self.assert_list_results(url, [program], 15)
def test_list_exclude_utm(self):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
......
......@@ -40,7 +40,7 @@ class CurrencyAdmin(admin.ModelAdmin):
class PartnerAdmin(admin.ModelAdmin):
fieldsets = (
(None, {
'fields': ('name', 'short_code', 'studio_url', 'site')
'fields': ('name', 'short_code', 'lms_url', 'studio_url', 'site')
}),
(_('OpenID Connect'), {
'description': _(
......
"""
API Client for LMS.
"""
import logging
from django.core.cache import cache
from edx_rest_api_client.client import EdxRestApiClient
from edx_rest_api_client.exceptions import SlumberBaseException
from requests.exceptions import ConnectionError, Timeout # pylint: disable=redefined-builtin
from course_discovery.apps.api.utils import get_cache_key
logger = logging.getLogger(__name__)
class LMSAPIClient(object):
"""
API Client for communication between discovery and LMS.
"""
def __init__(self, site, user):
self.site = site
self.user = user
self.client = EdxRestApiClient(self.site.partner.lms_url, oauth_access_token=user.access_token)
def get_api_access_request(self, user):
"""
Get API Access Requests made by the given user.
Arguments:
user (User): Django User.
Returns:
(dict): API Access requests made by the given user.
Examples:
>> user = User.objects.get(username='staff')
>> lms_api_client.get_api_access_requests(user)
{
"id": 1,
"created": "2017-09-25T08:37:05.872566Z",
"modified": "2017-09-25T08:37:47.412496Z",
"user": 5,
"status": "approved",
"website": "https://example.com/",
"reason": "Example Reason",
"company_name": "Example Inc",
"company_address": "Example Address",
"site": 1,
"contacted": True
}
"""
resource = 'api-admin/api/v1/api_access_request/'
query_parameters = {}
# Since Staff user has access to all API Access Requests and we want to limit the response.
# So, we will filter by username for staff users.
if user.is_staff:
query_parameters = {
'user__username': user.username
}
cache_key = get_cache_key(username=user.username, resource=resource)
api_access_request = cache.get(cache_key)
if not api_access_request:
try:
results = getattr(self.client, resource).get(**query_parameters)['results']
if len(results) > 1:
logger.warning(
'Multiple APIAccessRequest models returned from LMS API for user [%s].',
user.username,
)
api_access_request = results[0]
cache.set(cache_key, api_access_request, 60 * 60)
except (SlumberBaseException, ConnectionError, Timeout):
logger.exception('Failed to fetch API Access Request from LMS for user "%s".', user.username)
except (IndexError, KeyError):
# This should not happen as user must always have at-least one api-access-request.
logger.exception('APIAccessRequest model not found for user [%s].', user.username)
return api_access_request
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-09-28 10:50
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0005_auto_20170830_1246'),
]
operations = [
migrations.AddField(
model_name='partner',
name='lms_url',
field=models.URLField(blank=True, max_length=255, null=True, verbose_name='LMS URL'),
),
]
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-10-04 11:33
from __future__ import unicode_literals
from django.db import migrations
SWITCH = 'use_company_name_as_utm_source_value'
def create_switch(apps, schema_editor):
"""Create the use_company_name_as_utm_source_value switch."""
Switch = apps.get_model('waffle', 'Switch')
Switch.objects.get_or_create(name=SWITCH, defaults={'active': False})
def delete_switch(apps, schema_editor):
"""Delete the use_company_name_as_utm_source_value switch."""
Switch = apps.get_model('waffle', 'Switch')
Switch.objects.filter(name=SWITCH).delete()
class Migration(migrations.Migration):
dependencies = [
('core', '0006_partner_lms_url'),
('waffle', '0001_initial'),
]
operations = [
migrations.RunPython(create_switch, reverse_code=delete_switch),
]
......@@ -84,6 +84,7 @@ class Partner(TimeStampedModel):
oidc_secret = models.CharField(max_length=255, null=True, verbose_name=_('OpenID Connect Secret'))
studio_url = models.URLField(max_length=255, null=True, blank=True, verbose_name=_('Studio URL'))
site = models.OneToOneField(Site, on_delete=models.PROTECT)
lms_url = models.URLField(max_length=255, null=True, blank=True, verbose_name=_('LMS URL'))
def __str__(self):
return self.name
......
......@@ -56,6 +56,7 @@ class PartnerFactory(factory.DjangoModelFactory):
oidc_secret = factory.Faker('sha256')
site = factory.SubFactory(SiteFactory)
studio_url = factory.Faker('url')
lms_url = factory.Faker('url')
class Meta(object):
model = Partner
import json
import logging
import pytest
import responses
from django.conf import settings
from haystack import connections as haystack_connections
......@@ -36,3 +39,59 @@ class ElasticsearchTestMixin(object):
for course in program.courses.all():
index.update_object(course)
self.reindex_course_runs(course)
class LMSAPIClientMixin(object):
def mock_api_access_request(self, lms_url, status=200, api_access_request_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
data = {
'count': 2,
'num_pages': 1,
'current_page': 1,
'results':
[
dict(
{
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
**(api_access_request_overrides or {})
)
],
'next': None,
'start': 0,
'previous': None
}
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
body=json.dumps(data),
content_type='application/json',
status=status
)
def mock_api_access_request_with_invalid_data(self, lms_url, status=200, response_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
data = response_overrides or {}
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
body=json.dumps(data),
content_type='application/json',
status=status
)
import logging
import responses
from django.test import TestCase
from course_discovery.apps.core.api_client import lms
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import LMSAPIClientMixin
from course_discovery.apps.core.tests.utils import MockLoggingHandler
class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
@classmethod
def setUpClass(cls):
super(TestLMSAPIClient, cls).setUpClass()
logger = logging.getLogger(lms.__name__)
cls.log_handler = MockLoggingHandler(level='DEBUG')
logger.addHandler(cls.log_handler)
cls.log_messages = cls.log_handler.messages
def setUp(self):
super(TestLMSAPIClient, self).setUp()
# Reset mock logger for each test.
self.log_handler.reset()
self.user = UserFactory.create()
self.partner = PartnerFactory.create()
self.lms = lms.LMSAPIClient(self.partner.site, self.user)
self.response = {
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
}
@responses.activate
def test_get_api_access_request(self):
"""
Verify that `get_api_access_request` returns correct value.
"""
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides=self.response
)
assert self.lms.get_api_access_request(self.user) == self.response
@responses.activate
def test_get_api_access_request_with_404_error(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request(
self.partner.lms_url, status=404
)
assert self.lms.get_api_access_request(self.user) is None
assert 'Failed to fetch API Access Request from LMS for user "%s".' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_empty_response(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_invalid_response(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
# API response without proper paginated structure.
# Following is an invalid response.
sample_invalid_response = {
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 5,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Example Inc',
'company_address': 'Example Address',
'site': 1,
'contacted': True
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_invalid_response
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_multiple_records(self):
"""
Verify that `get_api_access_request` logs a warning message and returns the first result
if endpoint returns multiple api-access-requests for a user.
"""
# API response without proper paginated structure.
# Following is an invalid response.
sample_response_with_multiple_users = {
'count': 2,
'num_pages': 1,
'current_page': 1,
'results':
[
{
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'declined',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
{
'id': 2,
'created': '2017-10-25T08:37:05.872566Z',
'modified': '2017-10-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
],
'next': None,
'start': 0,
'previous': None
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_response_with_multiple_users
)
assert self.lms.get_api_access_request(self.user)['company_name'] == 'Test Company'
assert 'Multiple APIAccessRequest models returned from LMS API for user [%s].' % self.user.username in \
self.log_messages['warning']
import json
import logging
import math
from urllib.parse import parse_qs, urlparse
......@@ -92,3 +93,40 @@ def mock_jpeg_callback():
return 200, {}, image_stream.getvalue()
return request_callback
class MockLoggingHandler(logging.Handler):
"""
Mock logging handler to check for expected logs.
Messages are available from an instance's ``messages`` dict, in order, indexed by
a lowercase log level string (e.g., 'debug', 'info', etc.).
"""
def __init__(self, *args, **kwargs):
self.messages = {
'debug': [],
'info': [],
'warning': [],
'error': [],
'critical': [],
}
super(MockLoggingHandler, self).__init__(*args, **kwargs)
def emit(self, record):
"""
Store a message from ``record`` in the instance's ``messages`` dict.
"""
self.acquire()
try:
self.messages[record.levelname.lower()].append(record.getMessage())
finally:
self.release()
def reset(self):
self.acquire()
try:
for message_list in self.messages.values():
message_list.clear()
finally:
self.release()
......@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n"
"POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
......@@ -194,6 +194,10 @@ msgid "Studio URL"
msgstr ""
#: apps/core/models.py
msgid "LMS URL"
msgstr ""
#: apps/core/models.py
msgid "Partner"
msgstr ""
......
......@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n"
"POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
......
......@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n"
"POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
......@@ -235,6 +235,10 @@ msgid "Studio URL"
msgstr "Stüdïö ÛRL Ⱡ'σяєм ιρѕυм ∂σłσ#"
#: apps/core/models.py
msgid "LMS URL"
msgstr "LMS ÛRL Ⱡ'σяєм ιρѕυм #"
#: apps/core/models.py
msgid "Partner"
msgstr "Pärtnér Ⱡ'σяєм ιρѕυм #"
......
......@@ -7,7 +7,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n"
"POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
......
......@@ -15,6 +15,14 @@ def course_run_states(request):
pytest fixture for providing test classes with attributes necessary to create
and test CourseRuns in all states affecting availability.
"""
# Set class attributes on the invoking test context.
request.cls.states, request.cls.available_states = get_course_run_states()
def get_course_run_states():
"""
Utility method to get course_run_states and available_states.
"""
now = datetime.datetime.now(pytz.UTC)
past = now - datetime.timedelta(days=30)
future = now + datetime.timedelta(days=30)
......@@ -126,6 +134,4 @@ def course_run_states(request):
]
]
# Set class attributes on the invoking test context.
request.cls.states = partial(product, *states)
request.cls.available_states = list(product(*available_states))
return partial(product, *states), list(product(*available_states))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment