Commit 295fa22d by Saleem Latif Committed by Douglas Hall

Change UTM Source value in API results returned by discovery API

parent ad1dca56
...@@ -4,9 +4,11 @@ import json ...@@ -4,9 +4,11 @@ import json
from urllib.parse import urlencode from urllib.parse import urlencode
import pytz import pytz
import waffle
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models.query import Prefetch from django.db.models.query import Prefetch
from django.utils.text import slugify
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer
from rest_framework import serializers from rest_framework import serializers
...@@ -15,6 +17,7 @@ from taggit_serializer.serializers import TaggitSerializer, TagListSerializerFie ...@@ -15,6 +17,7 @@ from taggit_serializer.serializers import TaggitSerializer, TagListSerializerFie
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.api_client.lms import LMSAPIClient
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import (FAQ, CorporateEndorsement, Course, CourseRun, Endorsement, from course_discovery.apps.course_metadata.models import (FAQ, CorporateEndorsement, Course, CourseRun, Endorsement,
Image, Organization, Person, PersonSocialNetwork, PersonWork, Image, Organization, Person, PersonSocialNetwork, PersonWork,
...@@ -114,11 +117,12 @@ SELECT_RELATED_FIELDS = { ...@@ -114,11 +117,12 @@ SELECT_RELATED_FIELDS = {
} }
def get_marketing_url_for_user(user, marketing_url, exclude_utm=False): def get_marketing_url_for_user(partner, user, marketing_url, exclude_utm=False):
""" """
Return the given marketing URL with affiliate query parameters for the user. Return the given marketing URL with affiliate query parameters for the user.
Arguments: Arguments:
partner (Partner): Partner instance containing information.
user (User): Used to construct UTM query parameters. user (User): Used to construct UTM query parameters.
marketing_url (str | None): Base URL to which UTM parameters may be appended. marketing_url (str | None): Base URL to which UTM parameters may be appended.
...@@ -134,12 +138,36 @@ def get_marketing_url_for_user(user, marketing_url, exclude_utm=False): ...@@ -134,12 +138,36 @@ def get_marketing_url_for_user(user, marketing_url, exclude_utm=False):
return marketing_url return marketing_url
else: else:
params = urlencode({ params = urlencode({
'utm_source': user.username, 'utm_source': get_utm_source_for_user(partner, user),
'utm_medium': user.referral_tracking_id, 'utm_medium': user.referral_tracking_id,
}) })
return '{url}?{params}'.format(url=marketing_url, params=params) return '{url}?{params}'.format(url=marketing_url, params=params)
def get_utm_source_for_user(partner, user):
"""
Return the utm source for the user.
Arguments:
partner (Partner): Partner instance containing information.
user (User): Used to construct UTM query parameters.
Returns:
str: username and company name slugified and combined together.
"""
utm_source = user.username
# If use_company_name_as_utm_source_value is enabled and lms_url value is set then
# use company name from API Access Request as utm_source.
if waffle.switch_is_active('use_company_name_as_utm_source_value') and partner.lms_url:
lms = LMSAPIClient(partner.site, user)
api_access_request = lms.get_api_access_request(user)
if api_access_request:
utm_source = '{} {}'.format(utm_source, api_access_request['company_name'])
return slugify(utm_source)
class TimestampModelSerializer(serializers.ModelSerializer): class TimestampModelSerializer(serializers.ModelSerializer):
"""Serializer for timestamped models.""" """Serializer for timestamped models."""
modified = serializers.DateTimeField() modified = serializers.DateTimeField()
...@@ -444,6 +472,7 @@ class MinimalCourseRunSerializer(TimestampModelSerializer): ...@@ -444,6 +472,7 @@ class MinimalCourseRunSerializer(TimestampModelSerializer):
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
return get_marketing_url_for_user( return get_marketing_url_for_user(
obj.course.partner,
self.context['request'].user, self.context['request'].user,
obj.marketing_url, obj.marketing_url,
exclude_utm=self.context.get('exclude_utm') exclude_utm=self.context.get('exclude_utm')
...@@ -585,6 +614,7 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -585,6 +614,7 @@ class CourseSerializer(MinimalCourseSerializer):
def get_marketing_url(self, obj): def get_marketing_url(self, obj):
return get_marketing_url_for_user( return get_marketing_url_for_user(
obj.partner,
self.context['request'].user, self.context['request'].user,
obj.marketing_url, obj.marketing_url,
exclude_utm=self.context.get('exclude_utm') exclude_utm=self.context.get('exclude_utm')
......
...@@ -6,35 +6,46 @@ from urllib.parse import urlencode ...@@ -6,35 +6,46 @@ from urllib.parse import urlencode
import ddt import ddt
import pytest import pytest
import pytz import pytz
import responses
from django.test import TestCase from django.test import TestCase
from django.utils.text import slugify
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from waffle.models import Switch
from waffle.testutils import override_switch
from course_discovery.apps.api.fields import ImageField, StdImageSerializerField from course_discovery.apps.api.fields import ImageField, StdImageSerializerField
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (AffiliateWindowSerializer, CatalogSerializer,
AffiliateWindowSerializer, CatalogSerializer, ContainedCourseRunsSerializer, ContainedCoursesSerializer, ContainedCourseRunsSerializer, ContainedCoursesSerializer,
CorporateEndorsementSerializer, CourseRunSearchSerializer, CourseRunSerializer, CourseRunWithProgramsSerializer, CorporateEndorsementSerializer, CourseRunSearchSerializer,
CourseSearchSerializer, CourseSerializer, CourseWithProgramsSerializer, EndorsementSerializer, FAQSerializer, CourseRunSerializer, CourseRunWithProgramsSerializer,
FlattenedCourseRunWithCourseSerializer, ImageSerializer, MinimalCourseRunSerializer, MinimalCourseSerializer, CourseSearchSerializer, CourseSerializer,
MinimalOrganizationSerializer, MinimalProgramCourseSerializer, MinimalProgramSerializer, NestedProgramSerializer, CourseWithProgramsSerializer, EndorsementSerializer, FAQSerializer,
OrganizationSerializer, PersonSerializer, PositionSerializer, PrerequisiteSerializer, ProgramSearchSerializer, FlattenedCourseRunWithCourseSerializer, ImageSerializer,
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer, MinimalCourseRunSerializer, MinimalCourseSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer MinimalOrganizationSerializer, MinimalProgramCourseSerializer,
) MinimalProgramSerializer, NestedProgramSerializer,
OrganizationSerializer, PersonSerializer, PositionSerializer,
PrerequisiteSerializer, ProgramSearchSerializer, ProgramSerializer,
ProgramTypeSerializer, SeatSerializer, SubjectSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer,
VideoSerializer, get_utm_source_for_user)
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin, LMSAPIClientMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
from course_discovery.apps.course_metadata.tests.factories import ( from course_discovery.apps.course_metadata.tests.factories import (CorporateEndorsementFactory, CourseFactory,
CorporateEndorsementFactory, CourseFactory, CourseRunFactory, EndorsementFactory, ExpectedLearningItemFactory, CourseRunFactory, EndorsementFactory,
ImageFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, PositionFactory, PrerequisiteFactory, ExpectedLearningItemFactory, ImageFactory,
ProgramFactory, ProgramTypeFactory, SeatFactory, SeatTypeFactory, SubjectFactory, VideoFactory JobOutlookItemFactory, OrganizationFactory,
) PersonFactory, PositionFactory, PrerequisiteFactory,
ProgramFactory, ProgramTypeFactory, SeatFactory,
SeatTypeFactory, SubjectFactory, VideoFactory)
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
...@@ -1384,3 +1395,57 @@ class TestTypeaheadProgramSearchSerializer: ...@@ -1384,3 +1395,57 @@ class TestTypeaheadProgramSearchSerializer:
result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0] result = SearchQuerySet().models(Program).filter(uuid=program.uuid)[0]
serializer = TypeaheadProgramSearchSerializer(result) serializer = TypeaheadProgramSearchSerializer(result)
return serializer return serializer
class TestGetUTMSourceForUser(LMSAPIClientMixin, TestCase):
def setUp(self):
super(TestGetUTMSourceForUser, self).setUp()
self.switch, __ = Switch.objects.update_or_create(
name='use_company_name_as_utm_source_value', defaults={'active': True}
)
self.user = UserFactory.create()
self.partner = PartnerFactory.create()
@override_switch('use_company_name_as_utm_source_value', active=False)
def test_with_waffle_switch_turned_off(self):
"""
Verify that `get_utm_source_for_user` returns User's username when waffle switch
`use_company_name_as_utm_source_value` is turned off.
"""
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
def test_with_missing_lms_url(self):
"""
Verify that `get_utm_source_for_user` returns default value if
`Partner.lms_url` is not set in the database.
"""
# Remove lms_url from partner.
self.partner.lms_url = ''
self.partner.save()
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_when_api_response_is_not_valid(self):
"""
Verify that `get_utm_source_for_user` returns default value if
LMS API does not return a valid response.
"""
self.mock_api_access_request(self.partner.lms_url, status=400)
assert get_utm_source_for_user(self.partner, self.user) == self.user.username
@responses.activate
def test_get_utm_source_for_user(self):
"""
Verify that `get_utm_source_for_user` returns correct value.
"""
company_name = 'Test Company'
expected_utm_source = slugify('{} {}'.format(self.user.username, company_name))
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides={'company_name': company_name},
)
assert get_utm_source_for_user(self.partner, self.user) == expected_utm_source
import hashlib
import logging import logging
import six
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -38,3 +40,27 @@ def get_query_param(request, name): ...@@ -38,3 +40,27 @@ def get_query_param(request, name):
return return
return cast2int(request.query_params.get(name), name) return cast2int(request.query_params.get(name), name)
def get_cache_key(**kwargs):
"""
Get MD5 encoded cache key for given arguments.
Here is the format of key before MD5 encryption.
key1:value1__key2:value2 ...
Example:
>>> get_cache_key(site_domain="example.com", resource="catalogs")
# Here is key format for above call
# "site_domain:example.com__resource:catalogs"
a54349175618ff1659dee0978e3149ca
Arguments:
**kwargs: Key word arguments that need to be present in cache key.
Returns:
An MD5 encoded key uniquely identified by the key word arguments.
"""
key = '__'.join(['{}:{}'.format(item, value) for item, value in six.iteritems(kwargs)])
return hashlib.md5(key.encode('utf-8')).hexdigest()
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import pytz import pytz
import responses import responses
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.cache import cache
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
...@@ -17,9 +18,12 @@ from course_discovery.apps.core.tests.factories import UserFactory ...@@ -17,9 +18,12 @@ from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, SeatFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, SeatFactory
from course_discovery.conftest import get_course_run_states
User = get_user_model() User = get_user_model()
STATES, AVAILABLE_STATES = get_course_run_states()
@ddt.ddt @ddt.ddt
@pytest.mark.usefixtures('course_run_states') @pytest.mark.usefixtures('course_run_states')
...@@ -42,6 +46,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -42,6 +46,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
) )
self.course = self.course_run.course self.course = self.course_run.course
self.refresh_index() self.refresh_index()
cache.clear()
def assert_catalog_created(self, **headers): def assert_catalog_created(self, **headers):
name = 'The Kitchen Sink' name = 'The Kitchen Sink'
...@@ -148,7 +153,10 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -148,7 +153,10 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(User.objects.count(), original_user_count) self.assertEqual(User.objects.count(), original_user_count)
def test_courses(self): @ddt.data(
*STATES()
)
def test_courses(self, state):
""" """
Verify the endpoint returns the list of available courses contained in Verify the endpoint returns the list of available courses contained in
the catalog, and that courses appearing in the response always have at the catalog, and that courses appearing in the response always have at
...@@ -156,39 +164,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -156,39 +164,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
""" """
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
for state in self.states(): Course.objects.all().delete()
Course.objects.all().delete()
course_run = CourseRunFactory(course__title='ABC Test Course') course_run = CourseRunFactory(course__title='ABC Test Course')
for function in state: for function in state:
function(course_run) function(course_run)
course_run.save() course_run.save()
if state in self.available_states: if state in AVAILABLE_STATES:
course = course_run.course course = course_run.course
# This run has no seats, but we still expect its parent course # This run has no seats, but we still expect its parent course
# to be included. # to be included.
filtered_course_run = CourseRunFactory(course=course) filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(18): with self.assertNumQueries(20):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
# Emulate prefetching behavior. # Emulate prefetching behavior.
filtered_course_run.delete() filtered_course_run.delete()
assert response.data['results'] == self.serialize_catalog_course([course], many=True) assert response.data['results'] == self.serialize_catalog_course([course], many=True)
# Any course appearing in the response must have at least one serialized run. # Any course appearing in the response must have at least one serialized run.
assert len(response.data['results'][0]['course_runs']) > 0 assert len(response.data['results'][0]['course_runs']) > 0
else: else:
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
assert response.data['results'] == [] assert response.data['results'] == []
def test_contains_for_course_key(self): def test_contains_for_course_key(self):
""" """
...@@ -217,7 +224,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -217,7 +224,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(18): with self.assertNumQueries(20):
response = self.client.get(url) response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run) course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
...@@ -2,6 +2,8 @@ import datetime ...@@ -2,6 +2,8 @@ import datetime
import ddt import ddt
import pytz import pytz
from django.core.cache import cache
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
...@@ -24,12 +26,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -24,12 +26,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
self.request.user = self.user self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory(partner=self.partner) self.course = CourseFactory(partner=self.partner)
cache.clear()
def test_get(self): def test_get(self):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(20): with self.assertNumQueries(21):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -38,7 +41,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -38,7 +41,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """ """ Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(13): with self.assertNumQueries(14):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), []) self.assertEqual(response.data.get('programs'), [])
...@@ -51,7 +54,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -51,7 +54,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1' url += '?include_deleted_programs=1'
with self.assertNumQueries(24): with self.assertNumQueries(25):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual( self.assertEqual(
...@@ -187,7 +190,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -187,7 +190,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(26): with self.assertNumQueries(27):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -203,7 +206,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -203,7 +206,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(39): with self.assertNumQueries(41):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
...@@ -214,7 +217,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -214,7 +217,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(39): with self.assertNumQueries(40):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import urllib.parse import urllib.parse
import pytest import pytest
from django.core.cache import cache
from django.test import RequestFactory from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
...@@ -41,6 +42,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -41,6 +42,7 @@ class TestProgramViewSet(SerializationMixin):
self.django_assert_num_queries = django_assert_num_queries self.django_assert_num_queries = django_assert_num_queries
self.partner = partner self.partner = partner
self.request = request self.request = request
cache.clear()
def create_program(self): def create_program(self):
organizations = [OrganizationFactory(partner=self.partner)] organizations = [OrganizationFactory(partner=self.partner)]
...@@ -86,7 +88,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -86,7 +88,7 @@ class TestProgramViewSet(SerializationMixin):
def test_retrieve(self, django_assert_num_queries): def test_retrieve(self, django_assert_num_queries):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
program = self.create_program() program = self.create_program()
with django_assert_num_queries(39): with django_assert_num_queries(40):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
...@@ -112,7 +114,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -112,7 +114,7 @@ class TestProgramViewSet(SerializationMixin):
partner=self.partner) partner=self.partner)
# property does not have the right values while being indexed # property does not have the right values while being indexed
del program._course_run_weeks_to_complete del program._course_run_weeks_to_complete
with django_assert_num_queries(28): with django_assert_num_queries(29):
response = self.assert_retrieve_success(program) response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program) assert response.data == self.serialize_program(program)
assert course_list == list(program.courses.all()) # pylint: disable=no-member assert course_list == list(program.courses.all()) # pylint: disable=no-member
...@@ -148,7 +150,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -148,7 +150,7 @@ class TestProgramViewSet(SerializationMixin):
""" Verify the endpoint returns a list of all programs. """ """ Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)] expected = [self.create_program() for __ in range(3)]
expected.reverse() expected.reverse()
self.assert_list_results(self.list_path, expected, 14) self.assert_list_results(self.list_path, expected, 15)
# Verify that repeated list requests use the cache. # Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 4) self.assert_list_results(self.list_path, expected, 4)
...@@ -272,7 +274,7 @@ class TestProgramViewSet(SerializationMixin): ...@@ -272,7 +274,7 @@ class TestProgramViewSet(SerializationMixin):
program.marketing_slug = SLUG program.marketing_slug = SLUG
program.save() program.save()
self.assert_list_results(url, [program], 14) self.assert_list_results(url, [program], 15)
def test_list_exclude_utm(self): def test_list_exclude_utm(self):
""" Verify the endpoint returns marketing URLs without UTM parameters. """ """ Verify the endpoint returns marketing URLs without UTM parameters. """
......
...@@ -40,7 +40,7 @@ class CurrencyAdmin(admin.ModelAdmin): ...@@ -40,7 +40,7 @@ class CurrencyAdmin(admin.ModelAdmin):
class PartnerAdmin(admin.ModelAdmin): class PartnerAdmin(admin.ModelAdmin):
fieldsets = ( fieldsets = (
(None, { (None, {
'fields': ('name', 'short_code', 'studio_url', 'site') 'fields': ('name', 'short_code', 'lms_url', 'studio_url', 'site')
}), }),
(_('OpenID Connect'), { (_('OpenID Connect'), {
'description': _( 'description': _(
......
"""
API Client for LMS.
"""
import logging
from django.core.cache import cache
from edx_rest_api_client.client import EdxRestApiClient
from edx_rest_api_client.exceptions import SlumberBaseException
from requests.exceptions import ConnectionError, Timeout # pylint: disable=redefined-builtin
from course_discovery.apps.api.utils import get_cache_key
logger = logging.getLogger(__name__)
class LMSAPIClient(object):
"""
API Client for communication between discovery and LMS.
"""
def __init__(self, site, user):
self.site = site
self.user = user
self.client = EdxRestApiClient(self.site.partner.lms_url, oauth_access_token=user.access_token)
def get_api_access_request(self, user):
"""
Get API Access Requests made by the given user.
Arguments:
user (User): Django User.
Returns:
(dict): API Access requests made by the given user.
Examples:
>> user = User.objects.get(username='staff')
>> lms_api_client.get_api_access_requests(user)
{
"id": 1,
"created": "2017-09-25T08:37:05.872566Z",
"modified": "2017-09-25T08:37:47.412496Z",
"user": 5,
"status": "approved",
"website": "https://example.com/",
"reason": "Example Reason",
"company_name": "Example Inc",
"company_address": "Example Address",
"site": 1,
"contacted": True
}
"""
resource = 'api-admin/api/v1/api_access_request/'
query_parameters = {}
# Since Staff user has access to all API Access Requests and we want to limit the response.
# So, we will filter by username for staff users.
if user.is_staff:
query_parameters = {
'user__username': user.username
}
cache_key = get_cache_key(username=user.username, resource=resource)
api_access_request = cache.get(cache_key)
if not api_access_request:
try:
results = getattr(self.client, resource).get(**query_parameters)['results']
if len(results) > 1:
logger.warning(
'Multiple APIAccessRequest models returned from LMS API for user [%s].',
user.username,
)
api_access_request = results[0]
cache.set(cache_key, api_access_request, 60 * 60)
except (SlumberBaseException, ConnectionError, Timeout):
logger.exception('Failed to fetch API Access Request from LMS for user "%s".', user.username)
except (IndexError, KeyError):
# This should not happen as user must always have at-least one api-access-request.
logger.exception('APIAccessRequest model not found for user [%s].', user.username)
return api_access_request
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-09-28 10:50
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0005_auto_20170830_1246'),
]
operations = [
migrations.AddField(
model_name='partner',
name='lms_url',
field=models.URLField(blank=True, max_length=255, null=True, verbose_name='LMS URL'),
),
]
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-10-04 11:33
from __future__ import unicode_literals
from django.db import migrations
SWITCH = 'use_company_name_as_utm_source_value'
def create_switch(apps, schema_editor):
"""Create the use_company_name_as_utm_source_value switch."""
Switch = apps.get_model('waffle', 'Switch')
Switch.objects.get_or_create(name=SWITCH, defaults={'active': False})
def delete_switch(apps, schema_editor):
"""Delete the use_company_name_as_utm_source_value switch."""
Switch = apps.get_model('waffle', 'Switch')
Switch.objects.filter(name=SWITCH).delete()
class Migration(migrations.Migration):
dependencies = [
('core', '0006_partner_lms_url'),
('waffle', '0001_initial'),
]
operations = [
migrations.RunPython(create_switch, reverse_code=delete_switch),
]
...@@ -84,6 +84,7 @@ class Partner(TimeStampedModel): ...@@ -84,6 +84,7 @@ class Partner(TimeStampedModel):
oidc_secret = models.CharField(max_length=255, null=True, verbose_name=_('OpenID Connect Secret')) oidc_secret = models.CharField(max_length=255, null=True, verbose_name=_('OpenID Connect Secret'))
studio_url = models.URLField(max_length=255, null=True, blank=True, verbose_name=_('Studio URL')) studio_url = models.URLField(max_length=255, null=True, blank=True, verbose_name=_('Studio URL'))
site = models.OneToOneField(Site, on_delete=models.PROTECT) site = models.OneToOneField(Site, on_delete=models.PROTECT)
lms_url = models.URLField(max_length=255, null=True, blank=True, verbose_name=_('LMS URL'))
def __str__(self): def __str__(self):
return self.name return self.name
......
...@@ -56,6 +56,7 @@ class PartnerFactory(factory.DjangoModelFactory): ...@@ -56,6 +56,7 @@ class PartnerFactory(factory.DjangoModelFactory):
oidc_secret = factory.Faker('sha256') oidc_secret = factory.Faker('sha256')
site = factory.SubFactory(SiteFactory) site = factory.SubFactory(SiteFactory)
studio_url = factory.Faker('url') studio_url = factory.Faker('url')
lms_url = factory.Faker('url')
class Meta(object): class Meta(object):
model = Partner model = Partner
import json
import logging import logging
import pytest import pytest
import responses
from django.conf import settings from django.conf import settings
from haystack import connections as haystack_connections from haystack import connections as haystack_connections
...@@ -36,3 +39,59 @@ class ElasticsearchTestMixin(object): ...@@ -36,3 +39,59 @@ class ElasticsearchTestMixin(object):
for course in program.courses.all(): for course in program.courses.all():
index.update_object(course) index.update_object(course)
self.reindex_course_runs(course) self.reindex_course_runs(course)
class LMSAPIClientMixin(object):
def mock_api_access_request(self, lms_url, status=200, api_access_request_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
data = {
'count': 2,
'num_pages': 1,
'current_page': 1,
'results':
[
dict(
{
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
**(api_access_request_overrides or {})
)
],
'next': None,
'start': 0,
'previous': None
}
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
body=json.dumps(data),
content_type='application/json',
status=status
)
def mock_api_access_request_with_invalid_data(self, lms_url, status=200, response_overrides=None):
"""
Mock the api access requests endpoint response of the LMS.
"""
data = response_overrides or {}
responses.add(
responses.GET,
lms_url.rstrip('/') + '/api-admin/api/v1/api_access_request/',
body=json.dumps(data),
content_type='application/json',
status=status
)
import logging
import responses
from django.test import TestCase
from course_discovery.apps.core.api_client import lms
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import LMSAPIClientMixin
from course_discovery.apps.core.tests.utils import MockLoggingHandler
class TestLMSAPIClient(LMSAPIClientMixin, TestCase):
@classmethod
def setUpClass(cls):
super(TestLMSAPIClient, cls).setUpClass()
logger = logging.getLogger(lms.__name__)
cls.log_handler = MockLoggingHandler(level='DEBUG')
logger.addHandler(cls.log_handler)
cls.log_messages = cls.log_handler.messages
def setUp(self):
super(TestLMSAPIClient, self).setUp()
# Reset mock logger for each test.
self.log_handler.reset()
self.user = UserFactory.create()
self.partner = PartnerFactory.create()
self.lms = lms.LMSAPIClient(self.partner.site, self.user)
self.response = {
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
}
@responses.activate
def test_get_api_access_request(self):
"""
Verify that `get_api_access_request` returns correct value.
"""
self.mock_api_access_request(
self.partner.lms_url, api_access_request_overrides=self.response
)
assert self.lms.get_api_access_request(self.user) == self.response
@responses.activate
def test_get_api_access_request_with_404_error(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request(
self.partner.lms_url, status=404
)
assert self.lms.get_api_access_request(self.user) is None
assert 'Failed to fetch API Access Request from LMS for user "%s".' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_empty_response(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_invalid_response(self):
"""
Verify that `get_api_access_request` returns None when api_access_request
API endpoint is not available.
"""
# API response without proper paginated structure.
# Following is an invalid response.
sample_invalid_response = {
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 5,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Example Inc',
'company_address': 'Example Address',
'site': 1,
'contacted': True
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_invalid_response
)
assert self.lms.get_api_access_request(self.user) is None
assert 'APIAccessRequest model not found for user [%s].' % self.user.username in \
self.log_messages['error']
@responses.activate
def test_get_api_access_request_with_multiple_records(self):
"""
Verify that `get_api_access_request` logs a warning message and returns the first result
if endpoint returns multiple api-access-requests for a user.
"""
# API response without proper paginated structure.
# Following is an invalid response.
sample_response_with_multiple_users = {
'count': 2,
'num_pages': 1,
'current_page': 1,
'results':
[
{
'id': 1,
'created': '2017-09-25T08:37:05.872566Z',
'modified': '2017-09-25T08:37:47.412496Z',
'user': 1,
'status': 'declined',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
{
'id': 2,
'created': '2017-10-25T08:37:05.872566Z',
'modified': '2017-10-25T08:37:47.412496Z',
'user': 1,
'status': 'approved',
'website': 'https://example.com/',
'reason': 'Example Reason',
'company_name': 'Test Company',
'company_address': 'Example Address',
'site': 1,
'contacted': True
},
],
'next': None,
'start': 0,
'previous': None
}
self.mock_api_access_request_with_invalid_data(
self.partner.lms_url, response_overrides=sample_response_with_multiple_users
)
assert self.lms.get_api_access_request(self.user)['company_name'] == 'Test Company'
assert 'Multiple APIAccessRequest models returned from LMS API for user [%s].' % self.user.username in \
self.log_messages['warning']
import json import json
import logging
import math import math
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
...@@ -92,3 +93,40 @@ def mock_jpeg_callback(): ...@@ -92,3 +93,40 @@ def mock_jpeg_callback():
return 200, {}, image_stream.getvalue() return 200, {}, image_stream.getvalue()
return request_callback return request_callback
class MockLoggingHandler(logging.Handler):
"""
Mock logging handler to check for expected logs.
Messages are available from an instance's ``messages`` dict, in order, indexed by
a lowercase log level string (e.g., 'debug', 'info', etc.).
"""
def __init__(self, *args, **kwargs):
self.messages = {
'debug': [],
'info': [],
'warning': [],
'error': [],
'critical': [],
}
super(MockLoggingHandler, self).__init__(*args, **kwargs)
def emit(self, record):
"""
Store a message from ``record`` in the instance's ``messages`` dict.
"""
self.acquire()
try:
self.messages[record.levelname.lower()].append(record.getMessage())
finally:
self.release()
def reset(self):
self.acquire()
try:
for message_list in self.messages.values():
message_list.clear()
finally:
self.release()
...@@ -7,7 +7,7 @@ msgid "" ...@@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n" "POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
...@@ -194,6 +194,10 @@ msgid "Studio URL" ...@@ -194,6 +194,10 @@ msgid "Studio URL"
msgstr "" msgstr ""
#: apps/core/models.py #: apps/core/models.py
msgid "LMS URL"
msgstr ""
#: apps/core/models.py
msgid "Partner" msgid "Partner"
msgstr "" msgstr ""
......
...@@ -7,7 +7,7 @@ msgid "" ...@@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n" "POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
......
...@@ -7,7 +7,7 @@ msgid "" ...@@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n" "POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
...@@ -235,6 +235,10 @@ msgid "Studio URL" ...@@ -235,6 +235,10 @@ msgid "Studio URL"
msgstr "Stüdïö ÛRL Ⱡ'σяєм ιρѕυм ∂σłσ#" msgstr "Stüdïö ÛRL Ⱡ'σяєм ιρѕυм ∂σłσ#"
#: apps/core/models.py #: apps/core/models.py
msgid "LMS URL"
msgstr "LMS ÛRL Ⱡ'σяєм ιρѕυм #"
#: apps/core/models.py
msgid "Partner" msgid "Partner"
msgstr "Pärtnér Ⱡ'σяєм ιρѕυм #" msgstr "Pärtnér Ⱡ'σяєм ιρѕυм #"
......
...@@ -7,7 +7,7 @@ msgid "" ...@@ -7,7 +7,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-10-06 17:23+0000\n" "POT-Creation-Date: 2017-10-12 11:27+0500\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
......
...@@ -15,6 +15,14 @@ def course_run_states(request): ...@@ -15,6 +15,14 @@ def course_run_states(request):
pytest fixture for providing test classes with attributes necessary to create pytest fixture for providing test classes with attributes necessary to create
and test CourseRuns in all states affecting availability. and test CourseRuns in all states affecting availability.
""" """
# Set class attributes on the invoking test context.
request.cls.states, request.cls.available_states = get_course_run_states()
def get_course_run_states():
"""
Utility method to get course_run_states and available_states.
"""
now = datetime.datetime.now(pytz.UTC) now = datetime.datetime.now(pytz.UTC)
past = now - datetime.timedelta(days=30) past = now - datetime.timedelta(days=30)
future = now + datetime.timedelta(days=30) future = now + datetime.timedelta(days=30)
...@@ -126,6 +134,4 @@ def course_run_states(request): ...@@ -126,6 +134,4 @@ def course_run_states(request):
] ]
] ]
# Set class attributes on the invoking test context. return partial(product, *states), list(product(*available_states))
request.cls.states = partial(product, *states)
request.cls.available_states = list(product(*available_states))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment