Commit e251012e by Vedran Karacic Committed by Vedran Karačić

Filter API data by partner

LEARNER-1119
parent 0c869dd7
import logging
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db.models import QuerySet
from django.utils.translation import ugettext as _
......@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import NotFound, PermissionDenied
from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program
......@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code
filters['partner'] = request.site.partner.short_code
return filters
......
......@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags = TagListSerializerField()
@classmethod
def prefetch_queryset(cls):
return Organization.objects.all().select_related('partner').prefetch_related('tags')
def prefetch_queryset(cls, partner):
return Organization.objects.filter(partner=partner).select_related('partner').prefetch_related('tags')
class Meta(MinimalOrganizationSerializer.Meta):
fields = MinimalOrganizationSerializer.Meta.fields + (
......@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
class Meta(MinimalCourseSerializer.Meta):
......@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
def get_course_runs(self, course):
......@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.all()
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)
def get_course_runs(self, course):
......@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
@classmethod
def prefetch_queryset(cls):
return Program.objects.all().select_related('type', 'partner').prefetch_related(
def prefetch_queryset(cls, partner):
return Program.objects.filter(partner=partner).select_related('type', 'partner').prefetch_related(
'excluded_course_runs',
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
......@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls):
def prefetch_queryset(cls, partner):
"""
Prefetch the related objects that will be serialized with a `Program`.
......@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
"""
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related(
return Program.objects.filter(partner=partner).select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs',
'expected_learning_items',
'faq',
......@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types',
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
)
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import RequestFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
class SiteMixin(object):
def setUp(self):
super(SiteMixin, self).setUp()
domain = 'testserver.fake'
self.client = self.client_class(SERVER_NAME=domain)
Site.objects.all().delete()
self.site = SiteFactory(id=settings.SITE_ID, domain=domain)
self.partner = PartnerFactory(site=self.site)
self.request = RequestFactory(SERVER_NAME=self.site.domain).get('')
self.request.site = self.site
......@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer
)
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
......@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(TestCase):
class MinimalCourseSerializerTests(SiteMixin, TestCase):
serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request):
......@@ -114,8 +115,8 @@ class MinimalCourseSerializerTests(TestCase):
def test_data(self):
request = make_request()
organizations = OrganizationFactory()
course = CourseFactory(authoring_organizations=[organizations])
organizations = OrganizationFactory(partner=self.partner)
course = CourseFactory(authoring_organizations=[organizations], partner=self.partner)
CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request)
......@@ -178,9 +179,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def setUp(self):
super().setUp()
self.request = make_request()
self.course = CourseFactory()
self.course = CourseFactory(partner=self.partner)
self.deleted_program = ProgramFactory(
courses=[self.course],
partner=self.partner,
status=ProgramStatus.Deleted
)
......
import ddt
import pytest
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import PermissionDenied
from django.test import RequestFactory, TestCase
from django.urls import reverse
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.views import api_docs_permission_denied_handler
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
@pytest.mark.django_db
class TestApiDocs:
class TestApiDocs(APITestCase):
"""
Regression tests introduced following LEARNER-1590.
"""
path = reverse('api_docs')
def test_api_docs(self, admin_client):
def test_api_docs(self):
"""
Verify that the API docs are available to authenticated clients.
"""
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
response = admin_client.get(self.path)
user = UserFactory(is_staff=True)
self.client.login(username=user.username, password=USER_PASSWORD)
response = self.client.get(self.path)
assert response.status_code == 200
def test_api_docs_redirect(self, client):
def test_api_docs_redirect(self):
"""
Verify that unauthenticated clients are redirected.
"""
response = client.get(self.path)
response = self.client.get(self.path)
assert response.status_code == 302
......
......@@ -4,6 +4,7 @@ import json
import responses
from django.conf import settings
from rest_framework.test import APITestCase as RestAPITestCase
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.serializers import (
......@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer,
OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer
)
from course_discovery.apps.api.tests.mixins import SiteMixin
class SerializationMixin(object):
......@@ -88,3 +90,7 @@ class OAuth2Mixin(object):
content_type='application/json',
status=status
)
class APITestCase(SiteMixin, RestAPITestCase):
pass
......@@ -7,10 +7,9 @@ import ddt
import pytz
from lxml import etree
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import AffiliateWindowSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
......@@ -46,8 +45,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def test_affiliate_with_supported_seats(self):
""" Verify that endpoint returns course runs for verified and professional seats only. """
with self.assertNumQueries(8):
response = self.client.get(self.affiliate_url)
response = self.client.get(self.affiliate_url)
self.assertEqual(response.status_code, 200)
root = ET.fromstring(response.content)
......@@ -130,7 +128,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs
self.client.force_authenticate(superuser)
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
......@@ -140,7 +138,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
self.assertEqual(response.status_code, 403)
catalog.viewers = [self.user]
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
......
......@@ -8,10 +8,9 @@ import pytz
import responses
from django.contrib.auth import get_user_model
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.v1.tests.test_views.mixins import OAuth2Mixin, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory
......@@ -31,6 +30,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
def setUp(self):
super(CatalogViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.force_authenticate(self.user)
self.catalog = CatalogFactory(query='title:abc*')
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30)
......@@ -172,7 +172,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# to be included.
filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(16):
with self.assertNumQueries(18):
response = self.client.get(url)
assert response.status_code == 200
......@@ -185,8 +185,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run.
assert len(response.data['results'][0]['course_runs']) > 0
else:
with self.assertNumQueries(2):
response = self.client.get(url)
response = self.client.get(url)
assert response.status_code == 200
assert response.data['results'] == []
......@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(17):
with self.assertNumQueries(18):
response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
......@@ -4,19 +4,16 @@ import urllib
import ddt
import pytz
from django.conf import settings
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory, APITestCase
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, PartnerFactory, ProgramFactory, SeatFactory
)
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, SeatFactory
@ddt.ddt
......@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
self.course_run = CourseRunFactory(course__partner=self.partner)
self.course_run_2 = CourseRunFactory(course__partner=self.partner)
self.refresh_index()
......@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def assert_list_results(self, url, expected, extra_context=None):
expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url)
......@@ -256,7 +240,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
'course_run_ids': self.course_run.key,
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
assert response.status_code == 200
self.assertEqual(
......@@ -268,18 +251,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
}
)
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
assert response.status_code == 400
def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({
'query': 'id:course*',
......
......@@ -4,9 +4,8 @@ import ddt
import pytz
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course
......@@ -22,14 +21,15 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(CourseViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory()
self.course = CourseFactory(partner=self.partner)
def test_get(self):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(18):
with self.assertNumQueries(20):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(11):
with self.assertNumQueries(13):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), [])
......@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1'
with self.assertNumQueries(22):
with self.assertNumQueries(24):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
......@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(24):
with self.assertNumQueries(26):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses = CourseFactory.create_batch(3)
courses = CourseFactory.create_batch(3, partner=self.partner)
courses = sorted(courses, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import uuid
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import Organization, OrganizationFactory
......@@ -14,6 +13,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(OrganizationViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
def test_authentication(self):
......@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_response_data_valid(self, response, organizations, many=True):
""" Asserts the response data (only) contains the expected organizations. """
actual = response.data
serializer_data = self.serialize_organization(organizations, many=many)
if many:
actual = actual['results']
self.assertEqual(actual, self.serialize_organization(organizations, many=many))
self.assertCountEqual(actual, serializer_data)
def assert_list_uuid_filter(self, organizations):
def assert_list_uuid_filter(self, organizations, expected_query_count):
""" Asserts the list endpoint supports filtering by UUID. """
with self.assertNumQueries(5):
organizations = sorted(organizations, key=lambda o: o.created)
with self.assertNumQueries(expected_query_count):
uuids = ','.join([organization.uuid.hex for organization in organizations])
url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids)
response = self.client.get(url)
......@@ -45,9 +45,8 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
self.assertEqual(response.status_code, 200)
self.assert_response_data_valid(response, organizations)
def assert_list_tag_filter(self, organizations, tags, expected_query_count=5):
def assert_list_tag_filter(self, organizations, tags, expected_query_count=7):
""" Asserts the list endpoint supports filtering by tags. """
with self.assertNumQueries(expected_query_count):
tags = ','.join(tags)
url = '{root}?tags={tags}'.format(root=self.list_path, tags=tags)
......@@ -58,10 +57,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list(self):
""" Verify the endpoint returns a list of all organizations. """
OrganizationFactory.create_batch(3, partner=self.partner)
OrganizationFactory.create_batch(3)
with self.assertNumQueries(5):
with self.assertNumQueries(7):
response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200)
......@@ -70,22 +68,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list_uuid_filter(self):
""" Verify the endpoint returns a list of organizations filtered by UUID. """
organizations = OrganizationFactory.create_batch(3)
organizations = OrganizationFactory.create_batch(3, partner=self.partner)
# Test with a single UUID
self.assert_list_uuid_filter([organizations[0]])
self.assert_list_uuid_filter([organizations[0]], 7)
# Test with multiple UUIDs
self.assert_list_uuid_filter(organizations)
self.assert_list_uuid_filter(organizations, 7)
def test_list_tag_filter(self):
""" Verify the endpoint returns a list of organizations filtered by tag. """
tag = 'test-org'
organizations = OrganizationFactory.create_batch(2)
organizations = OrganizationFactory.create_batch(2, partner=self.partner)
# If no organizations have been tagged, the endpoint should not return any data
self.assert_list_tag_filter([], [tag], expected_query_count=3)
self.assert_list_tag_filter([], [tag], expected_query_count=5)
# Tagged organizations should be returned
organizations[0].tags.add(tag)
......@@ -99,7 +97,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns details for a single organization. """
organization = OrganizationFactory()
organization = OrganizationFactory(partner=self.partner)
url = reverse('api:v1:organization-detail', kwargs={'uuid': organization.uuid})
response = self.client.get(url)
......
# pylint: disable=redefined-builtin,no-member
import ddt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import IntegrityError
from mock import mock
......@@ -8,20 +7,20 @@ from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.people import logger as people_logger
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.models import Person
from course_discovery.apps.course_metadata.people import MarketingSitePeople
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import (OrganizationFactory, PartnerFactory, PersonFactory,
PositionFactory)
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory, PositionFactory
User = get_user_model()
@ddt.ddt
class PersonViewSetTests(SerializationMixin, APITestCase):
class PersonViewSetTests(SerializationMixin, SiteMixin, APITestCase):
""" Tests for the person resource. """
people_list_url = reverse('api:v1:person-list')
......@@ -29,13 +28,9 @@ class PersonViewSetTests(SerializationMixin, APITestCase):
super(PersonViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
self.person = PersonFactory()
PositionFactory(person=self.person)
self.organization = OrganizationFactory()
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
self.person = PersonFactory(partner=self.partner)
self.organization = OrganizationFactory(partner=self.partner)
PositionFactory(person=self.person, organization=self.organization)
toggle_switch('publish_person_to_marketing_site', True)
self.expected_node = {
'resource': 'node', ''
......
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.models import ProgramType
from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory
......@@ -28,7 +27,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """
ProgramTypeFactory.create_batch(4)
expected = ProgramType.objects.all()
with self.assertNumQueries(5):
with self.assertNumQueries(6):
response = self.client.get(self.list_path)
assert response.status_code == 200
......@@ -39,7 +38,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type = ProgramTypeFactory()
url = reverse('api:v1:program_type-detail', kwargs={'slug': program_type.slug})
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.client.get(url)
assert response.status_code == 200
......
......@@ -3,10 +3,9 @@ import urllib.parse
import ddt
from django.core.cache import cache
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import MinimalProgramSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -25,16 +24,17 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(ProgramViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
# Clear the cache between test cases, so they don't interfere with each other.
cache.clear()
def create_program(self):
organizations = [OrganizationFactory()]
organizations = [OrganizationFactory(partner=self.partner)]
person = PersonFactory()
course = CourseFactory()
course = CourseFactory(partner=self.partner)
CourseRunFactory(course=course, staff=[person])
program = ProgramFactory(
......@@ -46,7 +46,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
expected_learning_items=ExpectedLearningItemFactory.create_batch(1),
job_outlook_items=JobOutlookItemFactory.create_batch(1),
banner_image=make_image_file('test_banner.jpg'),
video=VideoFactory()
video=VideoFactory(),
partner=self.partner
)
return program
......@@ -73,14 +74,14 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """
program = self.create_program()
with self.assertNumQueries(37):
with self.assertNumQueries(39):
response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
assert response.data == self.serialize_program(program)
# Verify that repeated retrieve requests use the cache.
with self.assertNumQueries(2):
with self.assertNumQueries(4):
self.assert_retrieve_success(program)
# Verify that requests including querystring parameters are cached separately.
......@@ -90,22 +91,25 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
@ddt.data(True, False)
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date):
""" Verify the number of queries is the same with sorting flag set to true. """
course_list = CourseFactory.create_batch(3)
course_list = CourseFactory.create_batch(3, partner=self.partner)
for course in course_list:
CourseRunFactory(course=course)
program = ProgramFactory(courses=course_list, order_courses_by_start_date=order_courses_by_start_date)
program = ProgramFactory(
courses=course_list,
order_courses_by_start_date=order_courses_by_start_date,
partner=self.partner)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
with self.assertNumQueries(26):
with self.assertNumQueries(28):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
self.assertEqual(course_list, list(program.courses.all())) # pylint: disable=no-member
def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory()
program = ProgramFactory(courses=[course])
with self.assertNumQueries(20):
course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner)
with self.assertNumQueries(22):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
......@@ -135,18 +139,18 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)]
expected.reverse()
self.assert_list_results(self.list_path, expected, 12)
self.assert_list_results(self.list_path, expected, 14)
# Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 2)
self.assert_list_results(self.list_path, expected, 4)
def test_uuids_only(self):
"""
Verify that the list view returns a simply list of UUIDs when the
uuids_only query parameter is passed.
"""
active = ProgramFactory.create_batch(3)
retired = [ProgramFactory(status=ProgramStatus.Retired)]
active = ProgramFactory.create_batch(3, partner=self.partner)
retired = [ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)]
programs = active + retired
querystring = {'uuids_only': 1}
......@@ -165,47 +169,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
program = ProgramFactory(type__name=program_type_name, partner=self.partner)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 8)
self.assert_list_results(url, [program], 10)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 3)
self.assert_list_results(url, [], 5)
def test_filter_by_types(self):
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
expected = ProgramFactory.create_batch(2)
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected.reverse()
type_slugs = [p.type.slug for p in expected]
url = self.list_path + '?types=' + ','.join(type_slugs)
# Create a third program, which should be filtered out.
ProgramFactory()
ProgramFactory(partner=self.partner)
self.assert_list_results(url, expected, 8)
self.assert_list_results(url, expected, 10)
def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
expected = ProgramFactory.create_batch(2)
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected.reverse()
uuids = [str(p.uuid) for p in expected]
url = self.list_path + '?uuids=' + ','.join(uuids)
# Create a third program, which should be filtered out.
ProgramFactory()
ProgramFactory(partner=self.partner)
self.assert_list_results(url, expected, 8)
self.assert_list_results(url, expected, 10)
@ddt.data(
(ProgramStatus.Unpublished, False, 3),
(ProgramStatus.Active, True, 8),
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='')
programs = ProgramFactory.create_batch(3, status=status)
ProgramFactory(marketing_slug='', partner=self.partner)
programs = ProgramFactory.create_batch(3, status=status, partner=self.partner)
programs.reverse()
expected = programs if is_marketable else []
......@@ -214,40 +218,40 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_status(self):
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
active = ProgramFactory(status=ProgramStatus.Active)
retired = ProgramFactory(status=ProgramStatus.Retired)
active = ProgramFactory(status=ProgramStatus.Active, partner=self.partner)
retired = ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)
url = self.list_path + '?status=active'
self.assert_list_results(url, [active], 8)
self.assert_list_results(url, [active], 10)
url = self.list_path + '?status=retired'
self.assert_list_results(url, [retired], 8)
self.assert_list_results(url, [retired], 10)
url = self.list_path + '?status=active&status=retired'
self.assert_list_results(url, [retired, active], 8)
self.assert_list_results(url, [retired, active], 10)
def test_filter_by_hidden(self):
""" Endpoint should filter programs by their hidden attribute value. """
hidden = ProgramFactory(hidden=True)
not_hidden = ProgramFactory(hidden=False)
hidden = ProgramFactory(hidden=True, partner=self.partner)
not_hidden = ProgramFactory(hidden=False, partner=self.partner)
url = self.list_path + '?hidden=True'
self.assert_list_results(url, [hidden], 8)
self.assert_list_results(url, [hidden], 10)
url = self.list_path + '?hidden=False'
self.assert_list_results(url, [not_hidden], 8)
self.assert_list_results(url, [not_hidden], 10)
url = self.list_path + '?hidden=1'
self.assert_list_results(url, [hidden], 8)
self.assert_list_results(url, [hidden], 10)
url = self.list_path + '?hidden=0'
self.assert_list_results(url, [not_hidden], 8)
self.assert_list_results(url, [not_hidden], 10)
def test_list_exclude_utm(self):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
url = self.list_path + '?exclude_utm=1'
program = self.create_program()
self.assert_list_results(url, [program], 12, extra_context={'exclude_utm': 1})
self.assert_list_results(url, [program], 14, extra_context={'exclude_utm': 1})
def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """
......
......@@ -3,13 +3,12 @@ import json
import urllib.parse
import ddt
from django.conf import settings
from django.urls import reverse
from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.views.search import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
......@@ -88,14 +87,8 @@ class SynonymTestMixin:
self.assertDictEqual(response1, response2)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt
class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
""" Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets')
......@@ -162,7 +155,9 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
return course_run, response_data
def build_facet_url(self, params):
return 'http://testserver{path}?{query}'.format(path=self.faceted_path, query=urllib.parse.urlencode(params))
return 'http://testserver.fake{path}?{query}'.format(
path=self.faceted_path, query=urllib.parse.urlencode(params)
)
def test_invalid_query_facet(self):
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
......@@ -271,7 +266,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
)
self.reindex_courses(program)
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -295,7 +290,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=[course_run.course], status=program_status)
self.reindex_courses(active_program)
with self.assertNumQueries(5):
with self.assertNumQueries(6):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -313,7 +308,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt
class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets')
......@@ -438,7 +433,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
assert expected == actual
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead')
......@@ -620,23 +615,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
self.serialize_program(harvard_program)]
}
self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
......@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
......@@ -94,6 +94,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
course_runs = CourseRun.objects.active().enrollable().marketable()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset(
self.request.site.partner,
queryset=queryset,
course_runs=course_runs
)
......
......@@ -9,14 +9,13 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.utils import get_query_param
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member
class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
class CourseRunViewSet(viewsets.ModelViewSet):
""" CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter
......@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
multiple: false
"""
q = self.request.query_params.get('q')
partner = self.get_partner()
partner = self.request.site.partner
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
......@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
type: string
paramType: query
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden
description: Filter based on wether the course run is hidden from search.
required: false
......@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner = self.get_partner()
partner = self.request.site.partner
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
......
......@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class = filters.CourseFilter
lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseWithProgramsSerializer
......@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
q = self.request.query_params.get('q')
if q:
queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset, partner=partner)
else:
if get_query_param(self.request, 'include_hidden_course_runs'):
course_runs = CourseRun.objects.all()
course_runs = CourseRun.objects.filter(course__partner=partner)
else:
course_runs = CourseRun.objects.exclude(hidden=True)
course_runs = CourseRun.objects.filter(course__partner=partner).exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active()
......@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset,
course_runs=course_runs
course_runs=course_runs,
partner=partner
)
return queryset.order_by(Lower('key'))
......
......@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
permission_classes = (IsAuthenticated,)
queryset = serializers.OrganizationSerializer.prefetch_queryset()
serializer_class = serializers.OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination.
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
return serializers.OrganizationSerializer.prefetch_queryset(partner=partner)
def list(self, request, *args, **kwargs):
""" Retrieve a list of all organizations. """
return super(OrganizationViewSet, self).list(request, *args, **kwargs)
......
......@@ -7,7 +7,6 @@ from rest_framework.response import Response
from course_discovery.apps.api import serializers
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.exceptions import MarketingSiteAPIClientException, PersonToMarketingException
from course_discovery.apps.course_metadata.people import MarketingSitePeople
......@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member
class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
class PersonViewSet(viewsets.ModelViewSet):
""" PersonSerializer resource. """
lookup_field = 'uuid'
......@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
""" Create a new person. """
person_data = request.data
partner = self.get_partner()
partner = request.site.partner
person_data['partner'] = partner.id
serializer = self.get_serializer(data=person_data)
serializer.is_valid(raise_exception=True)
......
......@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
return self.get_serializer_class().prefetch_queryset()
partner = self.request.site.partner
return self.get_serializer_class().prefetch_queryset(partner)
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
......@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if get_query_param(self.request, 'uuids_only'):
# DRF serializers don't have good support for simple, flat
# representations like the one we want here.
queryset = self.filter_queryset(Program.objects.all())
queryset = self.filter_queryset(Program.objects.filter(partner=self.request.site.partner))
uuids = queryset.values_list('uuid', flat=True)
return Response(uuids)
......
......@@ -12,7 +12,6 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
......@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(PartnerMixin, APIView):
class TypeaheadSearchView(APIView):
""" Typeahead for courses and programs. """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
......@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView):
type: string
"""
query = request.query_params.get('q')
partner = self.get_partner()
partner = request.site.partner
if not query:
raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner)
......
......@@ -3,10 +3,11 @@ import json
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
class UserAutocompleteTests(TestCase):
class UserAutocompleteTests(SiteMixin, TestCase):
""" Tests for user autocomplete lookups."""
def setUp(self):
......
from django.conf import settings
from django.core.cache import cache
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.throttles import OverridableUserRateThrottle
class RateLimitingTest(APITestCase):
class RateLimitingTest(SiteMixin, APITestCase):
"""
Testing rate limiting of API calls.
"""
......@@ -16,8 +16,6 @@ class RateLimitingTest(APITestCase):
def setUp(self):
super(RateLimitingTest, self).setUp()
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
self.url = reverse('api_docs')
self.user = UserFactory()
self.client.login(username=self.user.username, password=USER_PASSWORD)
......
......@@ -9,18 +9,20 @@ from django.test.utils import override_settings
from django.urls import reverse
from django.utils.encoding import force_text
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.constants import Status
User = get_user_model()
class HealthTests(TestCase):
class HealthTests(SiteMixin, TestCase):
"""Tests of the health endpoint."""
def test_all_services_available(self):
"""Test that the endpoint reports when all services are healthy."""
self._assert_health(200, Status.OK, Status.OK)
@mock.patch('django.contrib.sites.middleware.get_current_site', mock.Mock(return_value=None))
def test_database_outage(self):
"""Test that the endpoint reports when the database is unavailable."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
......@@ -42,7 +44,7 @@ class HealthTests(TestCase):
self.assertJSONEqual(force_text(response.content), expected_data)
class AutoAuthTests(TestCase):
class AutoAuthTests(SiteMixin, TestCase):
""" Auto Auth view tests. """
AUTO_AUTH_PATH = reverse('auto_auth')
......
......@@ -11,6 +11,7 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import Select
from selenium.webdriver.support.wait import WebDriverWait
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -23,7 +24,7 @@ from course_discovery.apps.course_metadata.tests import factories
# pylint: disable=no-member
@ddt.ddt
class AdminTests(TestCase):
class AdminTests(SiteMixin, TestCase):
""" Tests Admin page."""
def setUp(self):
......@@ -190,7 +191,7 @@ class AdminTests(TestCase):
self.assertEqual(response.status_code, 200)
class ProgramAdminFunctionalTests(LiveServerTestCase):
class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
""" Functional Tests for Admin page."""
# Required for access to initial data loaded in migrations (e.g., LanguageTags).
serialized_rollback = True
......@@ -224,7 +225,6 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
def setUp(self):
super().setUp()
# ContentTypeManager uses a cache to speed up ContentType retrieval. This
# cache persists across tests. This is fine in the context of a regular
# TestCase which uses a transaction to reset the database between tests.
......@@ -238,6 +238,9 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
# stale ContentType objects from being used.
ContentType.objects.clear_cache()
self.site.domain = self.live_server_url.strip('http://')
self.site.save()
self.course_runs = factories.CourseRunFactory.create_batch(2)
self.courses = [course_run.course for course_run in self.course_runs]
......@@ -349,7 +352,7 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
self.assertEqual(self.program.subtitle, subtitle)
class ProgramEligibilityFilterTests(TestCase):
class ProgramEligibilityFilterTests(SiteMixin, TestCase):
""" Tests for Program Eligibility Filter class. """
parameter_name = 'eligible_for_one_click_purchase'
......
......@@ -5,6 +5,7 @@ import ddt
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory
......@@ -16,7 +17,7 @@ from course_discovery.apps.publisher.tests import factories
@ddt.ddt
class AutocompleteTests(TestCase):
class AutocompleteTests(SiteMixin, TestCase):
""" Tests for autocomplete lookups."""
def setUp(self):
super(AutocompleteTests, self).setUp()
......@@ -118,7 +119,7 @@ class AutocompleteTests(TestCase):
@ddt.ddt
class AutoCompletePersonTests(TestCase):
class AutoCompletePersonTests(SiteMixin, TestCase):
"""
Tests for person autocomplete lookups
"""
......
......@@ -2,17 +2,17 @@ import datetime
import urllib.parse
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.tests.test_views.test_search import (
DefaultPartnerMixin, ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
)
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, ProgramFactory
from course_discovery.apps.edx_catalog_extensions.api.serializers import DistinctCountsAggregateFacetSearchSerializer
class DistinctCountsAggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin,
class DistinctCountsAggregateSearchViewSetTests(SerializationMixin, LoginMixin,
ElasticsearchTestMixin, SynonymTestMixin, APITestCase):
path = reverse('extensions:api:v1:search-all-facets')
......
......@@ -4,13 +4,14 @@ import ddt
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member
@ddt.ddt
class AutocompleteTests(TestCase):
class AutocompleteTests(SiteMixin, TestCase):
""" Tests for autocomplete lookups."""
def setUp(self):
super(AutocompleteTests, self).setUp()
......
......@@ -40,7 +40,8 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
former_user = instance.user
instance = super(CourseUserRoleSerializer, self).update(instance, validated_data)
if not instance.role == PublisherUserRole.CourseTeam:
send_change_role_assignment_email(instance, former_user)
request = self.context['request']
send_change_role_assignment_email(instance, former_user, request.site)
return instance
......@@ -104,6 +105,7 @@ class CourseRunSerializer(serializers.ModelSerializer):
instance = super(CourseRunSerializer, self).update(instance, validated_data)
preview_url = validated_data.get('preview_url')
lms_course_id = validated_data.get('lms_course_id')
request = self.context['request']
if preview_url:
# Change ownership to CourseTeam.
......@@ -111,10 +113,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
if waffle.switch_is_active('enable_publisher_email_notifications'):
if preview_url:
send_email_preview_page_is_available(instance)
send_email_preview_page_is_available(instance, site=request.site)
elif lms_course_id:
send_email_for_studio_instance_created(instance)
send_email_for_studio_instance_created(instance, site=request.site)
return instance
......@@ -167,7 +169,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
state = validated_data.get('name')
request = self.context.get('request')
try:
instance.change_state(state=state, user=request.user)
instance.change_state(state=state, user=request.user, site=request.site)
except TransitionNotAllowed:
# pylint: disable=no-member
raise serializers.ValidationError(
......@@ -204,7 +206,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
if state:
try:
instance.change_state(state=state, user=request.user)
instance.change_state(state=state, user=request.user, site=request.site)
except TransitionNotAllowed:
# pylint: disable=no-member
raise serializers.ValidationError(
......@@ -223,6 +225,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
instance.save()
if waffle.switch_is_active('enable_publisher_email_notifications'):
send_email_preview_accepted(instance.course_run)
send_email_preview_accepted(instance.course_run, request.site)
return instance
......@@ -5,6 +5,7 @@ from django.test import RequestFactory, TestCase
from opaque_keys.edx.keys import CourseKey
from rest_framework.exceptions import ValidationError
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -20,7 +21,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
OrganizationExtensionFactory, SeatFactory)
class CourseUserRoleSerializerTests(TestCase):
class CourseUserRoleSerializerTests(SiteMixin, TestCase):
serializer_class = CourseUserRoleSerializer
def setUp(self):
......@@ -28,6 +29,7 @@ class CourseUserRoleSerializerTests(TestCase):
self.request = RequestFactory()
self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer)
self.request.user = self.course_user_role.user
self.request.site = self.site
def get_expected_data(self):
""" Helper method which will return expected serialize data. """
......@@ -138,7 +140,7 @@ class CourseRunSerializerTests(TestCase):
"""
self.course_run.preview_url = ''
self.course_run.save()
serializer = self.serializer_class(self.course_run)
serializer = self.serializer_class(self.course_run, context={'request': self.request})
serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'})
self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam)
......@@ -246,13 +248,12 @@ class CourseRevisionSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected)
class CourseStateSerializerTests(TestCase):
class CourseStateSerializerTests(SiteMixin, TestCase):
serializer_class = CourseStateSerializer
def setUp(self):
super(CourseStateSerializerTests, self).setUp()
self.course_state = CourseStateFactory(name=CourseStateChoices.Draft)
self.request = RequestFactory()
self.user = UserFactory()
self.request.user = self.user
......@@ -289,14 +290,13 @@ class CourseStateSerializerTests(TestCase):
serializer.update(self.course_state, data)
class CourseRunStateSerializerTests(TestCase):
class CourseRunStateSerializerTests(SiteMixin, TestCase):
serializer_class = CourseRunStateSerializer
def setUp(self):
super(CourseRunStateSerializerTests, self).setUp()
self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft)
self.course_run = self.run_state.course_run
self.request = RequestFactory()
self.user = UserFactory()
self.request.user = self.user
CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course)
......
......@@ -4,7 +4,6 @@ from urllib.parse import quote
import ddt
from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail
from django.db import IntegrityError
from django.test import TestCase
......@@ -14,6 +13,7 @@ from mock import mock, patch
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
@ddt.ddt
class CourseRoleAssignmentViewTests(TestCase):
class CourseRoleAssignmentViewTests(SiteMixin, TestCase):
def setUp(self):
super(CourseRoleAssignmentViewTests, self).setUp()
......@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(TestCase):
self.assertEqual(len(mail.outbox), 1)
class OrganizationGroupUserViewTests(TestCase):
class OrganizationGroupUserViewTests(SiteMixin, TestCase):
def setUp(self):
super(OrganizationGroupUserViewTests, self).setUp()
......@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(TestCase):
)
class UpdateCourseRunViewTests(TestCase):
class UpdateCourseRunViewTests(SiteMixin, TestCase):
def setUp(self):
super(UpdateCourseRunViewTests, self).setUp()
......@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_update_preview_url(self):
......@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(TestCase):
self.assertEqual(len(mail.outbox), 0)
class CourseRevisionDetailViewTests(TestCase):
class CourseRevisionDetailViewTests(SiteMixin, TestCase):
def setUp(self):
super(CourseRevisionDetailViewTests, self).setUp()
......@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(TestCase):
return self.client.get(path=course_revision_path)
class ChangeCourseStateViewTests(TestCase):
class ChangeCourseStateViewTests(SiteMixin, TestCase):
def setUp(self):
super(ChangeCourseStateViewTests, self).setUp()
......@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(TestCase):
body = mail.outbox[0].body.strip()
object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_change_course_state_with_error(self):
......@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(TestCase):
self._assert_email_sent(course_team_user, subject)
class ChangeCourseRunStateViewTests(TestCase):
class ChangeCourseRunStateViewTests(SiteMixin, TestCase):
def setUp(self):
super(ChangeCourseRunStateViewTests, self).setUp()
......@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(TestCase):
self.assertIn('has been published', mail.outbox[0].body.strip())
class RevertCourseByRevisionTests(TestCase):
class RevertCourseByRevisionTests(SiteMixin, TestCase):
def setUp(self):
super(RevertCourseByRevisionTests, self).setUp()
......@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(TestCase):
return self.client.put(path=course_revision_path)
class CoursesAutoCompleteTests(TestCase):
class CoursesAutoCompleteTests(SiteMixin, TestCase):
""" Tests for course autocomplete."""
def setUp(self):
......@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(TestCase):
self.assertEqual(len(data['results']), expected_length)
class AcceptAllByRevisionTests(TestCase):
class AcceptAllByRevisionTests(SiteMixin, TestCase):
def setUp(self):
super(AcceptAllByRevisionTests, self).setUp()
......
import logging
from django.conf import settings
from django.contrib.sites.models import Site
from django.core.mail.message import EmailMultiAlternatives
from django.template.loader import get_template
from django.urls import reverse
......@@ -16,11 +15,12 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled
logger = logging.getLogger(__name__)
def send_email_for_studio_instance_created(course_run):
def send_email_for_studio_instance_created(course_run, site):
""" Send an email to course team on studio instance creation.
Arguments:
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
try:
course_key = CourseKey.from_string(course_run.lms_course_id)
......@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run):
context = {
'course_run': course_run,
'course_run_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=object_path
host=site.domain.strip('/'), path=object_path
),
'course_name': course_run.course.title,
'from_address': from_address,
......@@ -65,12 +65,13 @@ def send_email_for_studio_instance_created(course_run):
raise Exception(error_message)
def send_email_for_course_creation(course, course_run):
def send_email_for_course_creation(course, course_run, site):
""" Send the emails for a course creation.
Arguments:
course (Course): Course object
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_created.txt'
html_template = 'publisher/email/course_created.html'
......@@ -91,7 +92,7 @@ def send_email_for_course_creation(course, course_run):
'course_team_name': course_team.get_full_name(),
'project_coordinator_name': project_coordinator.get_full_name(),
'dashboard_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=reverse('publisher:publisher_dashboard')
host=site.domain.strip('/'), path=reverse('publisher:publisher_dashboard')
),
'from_address': from_address,
'contact_us_email': project_coordinator.email
......@@ -113,12 +114,13 @@ def send_email_for_course_creation(course, course_run):
)
def send_email_for_send_for_review(course, user):
def send_email_for_send_for_review(course, user, site):
""" Send email when course is submitted for review.
Arguments:
course (Object): Course object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course/send_for_review.txt'
html_template = 'publisher/email/course/send_for_review.html'
......@@ -135,21 +137,22 @@ def send_email_for_send_for_review(course, user):
'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
)
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course %s', course.id)
def send_email_for_mark_as_reviewed(course, user):
def send_email_for_mark_as_reviewed(course, user, site):
""" Send email when course is marked as reviewed.
Arguments:
course (Object): Course object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course/mark_as_reviewed.txt'
html_template = 'publisher/email/course/mark_as_reviewed.html'
......@@ -166,16 +169,16 @@ def send_email_for_mark_as_reviewed(course, user):
'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
)
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications mark as reviewed of course %s', course.id)
def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user):
def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site):
""" Send email for course workflow state change.
Arguments:
......@@ -186,6 +189,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
html_template: (String): Email html template path
context: (Dict): Email template context
recipient_user: (Object): User object
site (Site): Current site
"""
if is_email_notification_enabled(recipient_user):
......@@ -202,7 +206,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
host=site.domain.strip('/'), path=course_page_path
)
}
)
......@@ -219,12 +223,13 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
email_msg.send()
def send_email_for_send_for_review_course_run(course_run, user):
def send_email_for_send_for_review_course_run(course_run, user, site):
""" Send email when course-run is submitted for review.
Arguments:
course-run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
course = course_run.course
course_key = CourseKey.from_string(course_run.lms_course_id)
......@@ -246,22 +251,23 @@ def send_email_for_send_for_review_course_run(course_run, user):
'run_number': course_key.run,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'project coordinators',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
),
'studio_url': course_run.studio_url
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course-run %s', course_run.id)
def send_email_for_mark_as_reviewed_course_run(course_run, user):
def send_email_for_mark_as_reviewed_course_run(course_run, user, site):
""" Send email when course-run is marked as reviewed.
Arguments:
course_run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html'
......@@ -284,21 +290,24 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user):
'run_number': course_key.run,
'sender_team': 'course team',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
)
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
send_course_workflow_email(
course, user, subject, txt_template, html_template, context, recipient_user, site
)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_to_publisher(course_run, user):
def send_email_to_publisher(course_run, user, site):
""" Send email to publisher when course-run is marked as reviewed.
Arguments:
course_run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html'
......@@ -330,7 +339,7 @@ def send_email_to_publisher(course_run, user):
'sender_team': sender_team,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
)
}
......@@ -349,11 +358,12 @@ def send_email_to_publisher(course_run, user):
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_preview_accepted(course_run):
def send_email_preview_accepted(course_run, site):
""" Send email for preview approved to publisher and project coordinator.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/preview_accepted.txt'
html_template = 'publisher/email/course_run/preview_accepted.html'
......@@ -382,10 +392,10 @@ def send_email_preview_accepted(course_run):
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
host=site.domain.strip('/'), path=course_page_path
)
}
template = get_template(txt_template)
......@@ -406,11 +416,12 @@ def send_email_preview_accepted(course_run):
raise Exception(message)
def send_email_preview_page_is_available(course_run):
def send_email_preview_page_is_available(course_run, site):
""" Send email for course preview available to course team.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/preview_available.txt'
html_template = 'publisher/email/course_run/preview_available.html'
......@@ -436,10 +447,10 @@ def send_email_preview_page_is_available(course_run):
'preview_link': course_run.preview_url,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
host=site.domain.strip('/'), path=course_page_path
),
'platform_name': settings.PLATFORM_NAME
}
......@@ -462,11 +473,12 @@ def send_email_preview_page_is_available(course_run):
raise Exception(error_message)
def send_course_run_published_email(course_run):
def send_course_run_published_email(course_run, site):
""" Send email when course run is published by publisher.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/published.txt'
html_template = 'publisher/email/course_run/published.html'
......@@ -492,10 +504,10 @@ def send_course_run_published_email(course_run):
'recipient_name': course_team_user.get_full_name() or course_team_user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
host=site.domain.strip('/'), path=course_page_path
),
'platform_name': settings.PLATFORM_NAME,
}
......@@ -518,12 +530,13 @@ def send_course_run_published_email(course_run):
raise Exception(error_message)
def send_change_role_assignment_email(course_role, former_user):
def send_change_role_assignment_email(course_role, former_user, site):
""" Send email for role assignment changed.
Arguments:
course_role (Object): CourseUserRole object
former_user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/role_assignment_changed.txt'
html_template = 'publisher/email/role_assignment_changed.html'
......@@ -549,7 +562,7 @@ def send_change_role_assignment_email(course_role, former_user):
'current_user_name': course_role.user.get_full_name() or course_role.user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=page_path
host=site.domain.strip('/'), path=page_path
),
'platform_name': settings.PLATFORM_NAME,
}
......@@ -572,11 +585,12 @@ def send_change_role_assignment_email(course_role, former_user):
raise Exception(error_message)
def send_email_for_seo_review(course):
def send_email_for_seo_review(course, site):
""" Send email when course is submitted for seo review.
Arguments:
course (Object): Course object
site (Site): Current site
"""
txt_template = 'publisher/email/course/seo_review.txt'
html_template = 'publisher/email/course/seo_review.html'
......@@ -597,7 +611,7 @@ def send_email_for_seo_review(course):
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email,
'course_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
host=site.domain.strip('/'), path=course_page_path
)
}
......@@ -615,11 +629,12 @@ def send_email_for_seo_review(course):
logger.exception('Failed to send email notifications for legal review requested of course %s', course.id)
def send_email_for_published_course_run_editing(course_run):
def send_email_for_published_course_run_editing(course_run, site):
""" Send email when published course-run is edited.
Arguments:
course-run (Object): Course Run object
site (Site): Current site
"""
try:
course = course_run.course
......@@ -644,7 +659,7 @@ def send_email_for_published_course_run_editing(course_run):
'recipient_name': publisher_user.get_full_name() or publisher_user.username,
'contact_us_email': course.project_coordinator.email,
'course_run_page_url': 'https://{host}{path}'.format(
host=Site.objects.get_current().domain.strip('/'), path=object_path
host=site.domain.strip('/'), path=object_path
),
'course_run_number': course_key.run,
}
......
......@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
# TODO: send email etc.
pass
def change_state(self, state, user):
def change_state(self, state, user, site=None):
"""
Change course workflow state and ownership also send emails if required.
"""
......@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
elif user_role.role == PublisherUserRole.CourseTeam:
self.change_owner_role(PublisherUserRole.MarketingReviewer)
if is_notifications_enabled:
emails.send_email_for_seo_review(self.course)
emails.send_email_for_seo_review(self.course, site)
self.review()
if is_notifications_enabled:
emails.send_email_for_send_for_review(self.course, user)
emails.send_email_for_send_for_review(self.course, user, site)
elif state == CourseStateChoices.Approved:
user_role = self.course.course_user_roles.get(user=user)
......@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
self.approved()
if is_notifications_enabled:
emails.send_email_for_mark_as_reviewed(self.course, user)
emails.send_email_for_mark_as_reviewed(self.course, user, site)
self.save()
......@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
pass
@transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published)
def published(self):
emails.send_course_run_published_email(self.course_run)
def published(self, site):
emails.send_course_run_published_email(self.course_run, site)
def change_state(self, state, user):
def change_state(self, state, user, site=None):
"""
Change course run workflow state and ownership also send emails if required.
"""
......@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.review()
if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_send_for_review_course_run(self.course_run, user)
emails.send_email_for_send_for_review_course_run(self.course_run, user, site)
elif state == CourseRunStateChoices.Approved:
user_role = self.course_run.course.course_user_roles.get(user=user)
......@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.approved()
if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user)
emails.send_email_to_publisher(self.course_run, user)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user, site)
emails.send_email_to_publisher(self.course_run, user, site)
elif state == CourseRunStateChoices.Published:
self.published()
self.published(site)
self.save()
......
......@@ -4,6 +4,7 @@ from django.test import TestCase
from django.urls import reverse
from guardian.shortcuts import get_group_perms
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory
from course_discovery.apps.publisher.choices import PublisherUserRole
......@@ -18,7 +19,7 @@ USER_PASSWORD = 'password'
# pylint: disable=no-member
class AdminTests(TestCase):
class AdminTests(SiteMixin, TestCase):
""" Tests Admin page."""
def setUp(self):
......@@ -81,7 +82,7 @@ class AdminTests(TestCase):
self.assertEqual(response.status_code, 200)
class OrganizationExtensionAdminTests(TestCase):
class OrganizationExtensionAdminTests(SiteMixin, TestCase):
""" Tests for OrganizationExtensionAdmin."""
def setUp(self):
......@@ -134,7 +135,7 @@ class OrganizationExtensionAdminTests(TestCase):
@ddt.ddt
class OrganizationUserRoleAdminTests(TestCase):
class OrganizationUserRoleAdminTests(SiteMixin, TestCase):
""" Tests for OrganizationUserRoleAdmin."""
def setUp(self):
......
......@@ -2,13 +2,13 @@
import mock
from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail
from django.test import TestCase
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher.tests.factories import UserAttributeFactory
class StudioInstanceCreatedEmailTests(TestCase):
class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
"""
Tests for the studio instance created email functionality.
"""
......@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(TestCase):
""" Verify that emails failure raise exception."""
with self.assertRaises(Exception) as ex:
emails.send_email_for_studio_instance_created(self.course_run)
emails.send_email_for_studio_instance_created(self.course_run, self.site)
error_message = 'Failed to send email notifications for course_run [{}]'.format(self.course_run.id)
self.assertEqual(ex.message, error_message)
def test_email_sent_successfully(self):
""" Verify that emails sent successfully for studio instance created."""
emails.send_email_for_studio_instance_created(self.course_run)
emails.send_email_for_studio_instance_created(self.course_run, self.site)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
self.assert_email_sent(
reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}),
......@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
self.assertIn('Enter course run content in Studio.', body)
self.assertIn('Thanks', body)
......@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
)
class CourseCreatedEmailTests(TestCase):
class CourseCreatedEmailTests(SiteMixin, TestCase):
""" Tests for the new course created email functionality. """
def setUp(self):
......@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(TestCase):
""" Verify that emails failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
l.check(
(
emails.logger.name,
......@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(TestCase):
def test_email_sent_successfully(self):
""" Verify that studio instance request email sent successfully."""
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
subject = 'Studio URL requested: {title}'.format(title=self.course_run.course.title)
self.assert_email_sent(subject)
......@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(TestCase):
user_attribute = UserAttributes.objects.get(user=self.user)
user_attribute.enable_email_notification = False
user_attribute.save()
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
self.assertEqual(len(mail.outbox), 0)
class SendForReviewEmailTests(TestCase):
class SendForReviewEmailTests(SiteMixin, TestCase):
""" Tests for the send for review email functionality. """
def setUp(self):
......@@ -168,7 +168,7 @@ class SendForReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review(self.course_state.course, self.user)
emails.send_email_for_send_for_review(self.course_state.course, self.user, self.site)
l.check(
(
emails.logger.name,
......@@ -180,7 +180,7 @@ class SendForReviewEmailTests(TestCase):
)
class CourseMarkAsReviewedEmailTests(TestCase):
class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
""" Tests for the mark as reviewed email functionality. """
def setUp(self):
......@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user)
emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user, self.site)
l.check(
(
emails.logger.name,
......@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
)
class CourseRunSendForReviewEmailTests(TestCase):
class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
""" Tests for the CourseRun send for review email functionality. """
def setUp(self):
......@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user, self.site)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user_2)
......@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2, self.site)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user)
......@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review_course_run(self.course_run, self.user)
emails.send_email_for_send_for_review_course_run(self.course_run, self.user, self.site)
l.check(
(
emails.logger.name,
......@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('View this course run in Publisher to review the changes or suggest edits.', body)
class CourseRunMarkAsReviewedEmailTests(TestCase):
class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
""" Tests for the CourseRun mark as reviewed email functionality. """
def setUp(self):
......@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user, self.site)
self.assertEqual(len(mail.outbox), 0)
def test_email_sent_by_course_team(self):
......@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2, self.site)
self.assert_email_sent(self.user)
def test_email_mark_as_reviewed_with_error(self):
""" Verify that email failure log error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user, self.site)
l.check(
(
emails.logger.name,
......@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_to_publisher(self.course_run_state.course_run, self.user)
emails.send_email_to_publisher(self.course_run_state.course_run, self.user, self.site)
self.assert_email_sent(self.user_3)
def test_email_to_publisher_with_error(self):
......@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with LogCapture(emails.logger.name) as l:
emails.send_email_to_publisher(self.course_run, self.user_3)
emails.send_email_to_publisher(self.course_run, self.user_3, self.site)
l.check(
(
emails.logger.name,
......@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('The review for this course run is complete.', body)
class CourseRunPreviewEmailTests(TestCase):
class CourseRunPreviewEmailTests(SiteMixin, TestCase):
"""
Tests for the course preview email functionality.
"""
......@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(TestCase):
lms_course_id = 'course-v1:edX+DemoX+Demo_Course'
self.run_state.course_run.lms_course_id = lms_course_id
emails.send_email_preview_accepted(self.run_state.course_run)
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
course_key = CourseKey.from_string(lms_course_id)
subject = 'Publication requested: {course_name} {run_number}'.format(
......@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.run_state.course_run.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('You can now publish this About page.', body)
......@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(TestCase):
with self.assertRaises(Exception) as ex:
self.assertEqual(str(ex.exception), message)
with LogCapture(emails.logger.name) as l:
emails.send_email_preview_accepted(self.run_state.course_run)
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
l.check(
(
emails.logger.name,
......@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(TestCase):
course_run.lms_course_id = 'course-v1:testX+testX1.0+2017T1'
course_run.save()
emails.send_email_preview_page_is_available(course_run)
emails.send_email_preview_page_is_available(course_run, self.site)
course_key = CourseKey.from_string(course_run.lms_course_id)
subject = 'Review requested: Preview for {course_name} {run_number}'.format(
......@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': course_run.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('A preview is now available for the', body)
......@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(TestCase):
""" Verify that exception raised on email failure."""
with self.assertRaises(Exception) as ex:
emails.send_email_preview_page_is_available(self.run_state.course_run)
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site)
error_message = 'Failed to send email notifications for preview available of course-run {}'.format(
self.run_state.course_run.id
)
......@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(TestCase):
def test_preview_available_email_with_notification_disabled(self):
""" Verify that email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.course_team_admin, enable_email_notification=False)
emails.send_email_preview_page_is_available(self.run_state.course_run)
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site)
self.assertEqual(len(mail.outbox), 0)
def test_preview_accepted_email_with_notification_disabled(self):
""" Verify that preview accepted email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.publisher, enable_email_notification=False)
emails.send_email_preview_accepted(self.run_state.course_run)
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
self.assertEqual(len(mail.outbox), 0)
class CourseRunPublishedEmailTests(TestCase):
class CourseRunPublishedEmailTests(SiteMixin, TestCase):
"""
Tests for course run published email functionality.
"""
......@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(TestCase):
"""
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save()
emails.send_course_run_published_email(self.course_run)
emails.send_course_run_published_email(self.course_run, self.site)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
subject = 'Publication complete: About page for {course_name} {run_number}'.format(
......@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(TestCase):
)
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex:
emails.send_course_run_published_email(self.course_run)
emails.send_course_run_published_email(self.course_run, self.site)
self.assertEqual(str(ex.exception), message)
class CourseChangeRoleAssignmentEmailTests(TestCase):
class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
"""
Tests email functionality for course role assignment changed.
"""
......@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
"""
Verify that course role assignment chnage email functionality works fine.
"""
emails.send_change_role_assignment_email(self.marketing_role, self.user)
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site)
expected_subject = '{role_name} changed for {course_title}'.format(
role_name=self.marketing_role.get_role_display().lower(),
course_title=self.course.title
......@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('has changed.', body)
......@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
)
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex:
emails.send_change_role_assignment_email(self.marketing_role, self.user)
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site)
self.assertEqual(str(ex.exception), message)
class SEOReviewEmailTests(TestCase):
class SEOReviewEmailTests(SiteMixin, TestCase):
""" Tests for the seo review email functionality. """
def setUp(self):
......@@ -626,7 +626,7 @@ class SEOReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_seo_review(self.course)
emails.send_email_for_seo_review(self.course, self.site)
l.check(
(
emails.logger.name,
......@@ -642,7 +642,7 @@ class SEOReviewEmailTests(TestCase):
Verify that seo review email functionality works fine.
"""
factories.CourseUserRoleFactory(course=self.course, role=PublisherUserRole.ProjectCoordinator)
emails.send_email_for_seo_review(self.course)
emails.send_email_for_seo_review(self.course, self.site)
expected_subject = 'Legal review requested: {title}'.format(title=self.course.title)
self.assertEqual(len(mail.outbox), 1)
......@@ -652,7 +652,7 @@ class SEOReviewEmailTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('determine OFAC status', body)
......@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
)
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save()
emails.send_email_for_published_course_run_editing(self.course_run)
emails.send_email_for_published_course_run_editing(self.course_run, self.site)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
......@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_published_course_run_editing(self.course_run)
emails.send_email_for_published_course_run_editing(self.course_run, self.site)
l.check(
(
emails.logger.name,
......
......@@ -6,7 +6,7 @@ from django.urls import reverse
from django_fsm import TransitionNotAllowed
from guardian.shortcuts import assign_perm
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag
......@@ -525,6 +525,8 @@ class CourseStateTests(TestCase):
def setUp(self):
super(CourseStateTests, self).setUp()
self.site = SiteFactory()
self.partner = PartnerFactory(site=self.site)
self.course = self.course_state.course
self.course.image = make_image_file('test_banner.jpg')
self.course.save()
......@@ -548,7 +550,7 @@ class CourseStateTests(TestCase):
"""
self.assertNotEqual(self.course_state.name, state)
self.course_state.change_state(state=state, user=self.user)
self.course_state.change_state(state=state, user=self.user, site=self.site)
self.assertEqual(self.course_state.name, state)
......@@ -561,7 +563,7 @@ class CourseStateTests(TestCase):
self.assertEqual(self.course_state.name, CourseStateChoices.Draft)
with self.assertRaises(TransitionNotAllowed):
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user)
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user, site=self.site)
def test_can_send_for_review(self):
"""
......@@ -673,6 +675,9 @@ class CourseRunStateTests(TestCase):
language_tag = LanguageTag(code='te-st', name='Test Language')
language_tag.save()
self.site = SiteFactory()
self.partner = PartnerFactory(site=self.site)
self.course_run.transcript_languages.add(language_tag)
self.course_run.language = language_tag
self.course_run.is_micromasters = True
......@@ -703,7 +708,7 @@ class CourseRunStateTests(TestCase):
Verify that we can change course-run state according to workflow.
"""
self.assertNotEqual(self.course_run_state.name, state)
self.course_run_state.change_state(state=state, user=self.user)
self.course_run_state.change_state(state=state, user=self.user, site=self.site)
self.assertEqual(self.course_run_state.name, state)
def test_with_invalid_parent_course_state(self):
......
......@@ -19,11 +19,13 @@ from opaque_keys.edx.keys import CourseKey
from pytz import timezone
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, OrganizationFactory, PersonFactory
from course_discovery.apps.course_metadata.tests.factories import (CourseFactory, OrganizationFactory, PersonFactory,
SubjectFactory)
from course_discovery.apps.ietf_language_tags.models import LanguageTag
from course_discovery.apps.publisher.choices import (CourseRunStateChoices, CourseStateChoices, InternalUserRole,
PublisherUserRole)
......@@ -42,7 +44,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
class CreateCourseViewTests(TestCase):
class CreateCourseViewTests(SiteMixin, TestCase):
""" Tests for the publisher `CreateCourseView`. """
def setUp(self):
......@@ -61,7 +63,6 @@ class CreateCourseViewTests(TestCase):
self.course = factories.CourseFactory()
self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
# creating default organizations roles
......@@ -269,7 +270,7 @@ class CreateCourseViewTests(TestCase):
)
class CreateCourseRunViewTests(TestCase):
class CreateCourseRunViewTests(SiteMixin, TestCase):
""" Tests for the publisher `UpdateCourseRunView`. """
def setUp(self):
......@@ -299,7 +300,6 @@ class CreateCourseRunViewTests(TestCase):
current_datetime = datetime.now(timezone('US/Central'))
self.course_run_dict['start'] = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
self.course_run_dict['end'] = (current_datetime + timedelta(days=3)).strftime('%Y-%m-%d %H:%M:%S')
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
def _pop_valuse_from_dict(self, data_dict, key_list):
......@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(TestCase):
@ddt.ddt
class CourseRunDetailTests(TestCase):
class CourseRunDetailTests(SiteMixin, TestCase):
""" Tests for the course-run detail view. """
def setUp(self):
......@@ -763,9 +763,8 @@ class CourseRunDetailTests(TestCase):
"""
self.client.logout()
self.client.login(username=self.user.username, password=USER_PASSWORD)
site = Site.objects.get(pk=settings.SITE_ID)
comment = CommentFactory(content_object=self.course_run, user=self.user, site=site)
comment = CommentFactory(content_object=self.course_run, user=self.user, site=self.site)
response = self.client.get(self.page_url)
self.assertEqual(response.status_code, 200)
self._assert_credits_seats(response, self.wrapped_course_run.credit_seat)
......@@ -779,7 +778,7 @@ class CourseRunDetailTests(TestCase):
# test decline comment appearing on detail page also.
decline_comment = CommentFactory(
content_object=self.course_run,
user=self.user, site=site, comment_type=CommentTypeChoices.Decline_Preview
user=self.user, site=self.site, comment_type=CommentTypeChoices.Decline_Preview
)
response = self.client.get(self.page_url)
self.assertContains(response, decline_comment.comment)
......@@ -1233,12 +1232,12 @@ class CourseRunDetailTests(TestCase):
# pylint: disable=attribute-defined-outside-init
@ddt.ddt
class DashboardTests(TestCase):
class DashboardTests(SiteMixin, TestCase):
""" Tests for the `Dashboard`. """
def setUp(self):
super(DashboardTests, self).setUp()
Site.objects.exclude(id=self.site.id).delete()
self.group_internal = Group.objects.get(name=INTERNAL_USER_GROUP_NAME)
self.group_project_coordinator = Group.objects.get(name=PROJECT_COORDINATOR_GROUP_NAME)
self.group_reviewer = Group.objects.get(name=REVIEWER_GROUP_NAME)
......@@ -1277,7 +1276,18 @@ class DashboardTests(TestCase):
def _create_course_assign_role(self, state, user, role):
""" Create course-run-state, course-user-role and return course-run. """
course_run_state = factories.CourseRunStateFactory(name=state, owner_role=role)
course = factories.CourseFactory(
primary_subject=SubjectFactory(partner=self.partner),
secondary_subject=SubjectFactory(partner=self.partner),
tertiary_subject=SubjectFactory(partner=self.partner)
)
course_run = factories.CourseRunFactory(course=course)
course_run_state = factories.CourseRunStateFactory(
name=state,
owner_role=role,
course_run=course_run
)
factories.CourseUserRoleFactory(course=course_run_state.course_run.course, role=role, user=user)
return course_run_state.course_run
......@@ -1301,7 +1311,7 @@ class DashboardTests(TestCase):
self.client.logout()
self.client.login(username=UserFactory(), password=USER_PASSWORD)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=11
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=12
)
self._assert_tabs_with_roles(response)
......@@ -1309,7 +1319,7 @@ class DashboardTests(TestCase):
def test_with_internal_group(self, tab):
""" Verify that internal user can see courses assigned to the groups. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
)
self.assertContains(response, '<li role="tab" id="tab-{tab}" class="tab"'.format(tab=tab))
......@@ -1324,7 +1334,7 @@ class DashboardTests(TestCase):
self.course_run_1.course.organizations.add(self.organization_extension.organization)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=11
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=12
)
self._assert_tabs_with_roles(response)
......@@ -1349,14 +1359,14 @@ class DashboardTests(TestCase):
)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=2, preview_count=1, queries_executed=21
studio_count=0, published_count=0, progress_count=2, preview_count=1, queries_executed=22
)
self._assert_tabs_with_roles(response)
def test_studio_request_course_runs_as_pc(self):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
)
self._assert_tabs_with_roles(response)
......@@ -1364,7 +1374,7 @@ class DashboardTests(TestCase):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
self.user1.groups.remove(self.group_project_coordinator)
response = self.assert_dashboard_response(
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=20
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=21
)
self._assert_tabs_with_roles(response)
......@@ -1375,7 +1385,7 @@ class DashboardTests(TestCase):
self.course_run_2.lms_course_id = 'test-2'
self.course_run_2.save()
response = self.assert_dashboard_response(
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=21
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=22
)
self.assertContains(response, 'No courses are currently ready for a Studio URL.')
......@@ -1384,7 +1394,7 @@ class DashboardTests(TestCase):
self.course_run_3.course_run_state.name = CourseRunStateChoices.Draft
self.course_run_3.course_run_state.save()
response = self.assert_dashboard_response(
studio_count=3, published_count=0, progress_count=3, preview_count=1, queries_executed=24
studio_count=3, published_count=0, progress_count=3, preview_count=1, queries_executed=25
)
self.assertContains(response, 'No About pages have been published yet')
self._assert_tabs_with_roles(response)
......@@ -1392,7 +1402,7 @@ class DashboardTests(TestCase):
def test_published_course_runs(self):
""" Verify that published tab loads course runs list. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
)
self.assertContains(response, self.table_class.format(id='published'))
self.assertContains(response, 'About pages for the following course runs have been published in the')
......@@ -1410,7 +1420,7 @@ class DashboardTests(TestCase):
# Verify that user cannot see any published course run
self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=15
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=16
)
# assign user course role
......@@ -1420,7 +1430,7 @@ class DashboardTests(TestCase):
# Verify that user can see 1 published course run
response = self.assert_dashboard_response(
studio_count=0, published_count=1, progress_count=0, preview_count=0, queries_executed=16
studio_count=0, published_count=1, progress_count=0, preview_count=0, queries_executed=17
)
self._assert_tabs_with_roles(response)
......@@ -1434,14 +1444,14 @@ class DashboardTests(TestCase):
publisher_admin.groups.add(Group.objects.get(name=ADMIN_GROUP_NAME))
self.client.login(username=publisher_admin.username, password=USER_PASSWORD)
response = self.assert_dashboard_response(
studio_count=4, published_count=1, progress_count=3, preview_count=1, queries_executed=20
studio_count=4, published_count=1, progress_count=3, preview_count=1, queries_executed=21
)
self._assert_tabs_with_roles(response)
def test_with_preview_ready_course_runs(self):
""" Verify that preview ready tabs loads the course runs list. """
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
)
self.assertContains(response, self.table_class.format(id='preview'))
self.assertContains(response, 'About page previews for the following course runs are available for course team')
......@@ -1453,7 +1463,7 @@ class DashboardTests(TestCase):
self.course_run_2.course_run_state.name = CourseRunStateChoices.Draft
self.course_run_2.course_run_state.save()
response = self.assert_dashboard_response(
studio_count=2, preview_count=0, progress_count=3, published_count=1, queries_executed=22
studio_count=2, preview_count=0, progress_count=3, published_count=1, queries_executed=23
)
self._assert_tabs_with_roles(response)
......@@ -1462,7 +1472,7 @@ class DashboardTests(TestCase):
preview url is added or not.
"""
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
)
self._assert_tabs_with_roles(response)
......@@ -1470,14 +1480,14 @@ class DashboardTests(TestCase):
self.course_run_2.preview_url = None
self.course_run_2.save()
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
)
self._assert_tabs_with_roles(response)
def test_with_in_progress_course_runs(self):
""" Verify that in progress tabs loads the course runs list. """
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
)
self.assertContains(response, self.table_class.format(id='in-progress'))
self._assert_tabs_with_roles(response)
......@@ -1513,7 +1523,7 @@ class DashboardTests(TestCase):
self.client.logout()
self.client.login(username=pc_user.username, password=USER_PASSWORD)
with self.assertNumQueries(11):
with self.assertNumQueries(12):
response = self.client.get(self.page_url)
for tab in ['progress', 'preview', 'studio', 'published']:
......@@ -1523,7 +1533,7 @@ class DashboardTests(TestCase):
"""
Verify that site_name is available in context.
"""
with self.assertNumQueries(23):
with self.assertNumQueries(24):
response = self.client.get(self.page_url)
site = Site.objects.first()
self.assertEqual(response.context['site_name'], site.name)
......@@ -1542,13 +1552,12 @@ class DashboardTests(TestCase):
course_run.course_run_state.owner_role = PublisherUserRole.CourseTeam
course_run.course_run_state.save()
with self.assertNumQueries(25):
with self.assertNumQueries(26):
response = self.client.get(self.page_url)
site = Site.objects.first()
self._assert_filter_counts(response, 'All', 3)
self._assert_filter_counts(response, 'With Course Team', 2)
self._assert_filter_counts(response, 'With {site_name}'.format(site_name=site.name), 1)
self._assert_filter_counts(response, 'With {site_name}'.format(site_name=self.site.name), 1)
def _assert_filter_counts(self, response, expected_label, count):
"""
......@@ -1559,7 +1568,7 @@ class DashboardTests(TestCase):
self.assertContains(response, expected_count, count=1)
class ToggleEmailNotificationTests(TestCase):
class ToggleEmailNotificationTests(SiteMixin, TestCase):
""" Tests for `ToggleEmailNotification` view. """
def setUp(self):
......@@ -1592,7 +1601,7 @@ class ToggleEmailNotificationTests(TestCase):
self.assertEqual(is_email_notification_enabled(user), is_enabled)
class CourseListViewTests(TestCase):
class CourseListViewTests(SiteMixin, TestCase):
""" Tests for `CourseListView` """
def setUp(self):
......@@ -1606,12 +1615,12 @@ class CourseListViewTests(TestCase):
def test_courses_with_no_courses(self):
""" Verify that user cannot see any course on course list page. """
self.assert_course_list_page(course_count=0, queries_executed=8)
self.assert_course_list_page(course_count=0, queries_executed=9)
def test_courses_with_admin(self):
""" Verify that admin user can see all courses on course list page. """
self.user.groups.add(Group.objects.get(name=ADMIN_GROUP_NAME))
self.assert_course_list_page(course_count=10, queries_executed=31)
self.assert_course_list_page(course_count=10, queries_executed=32)
def test_courses_with_course_user_role(self):
""" Verify that internal user can see course on course list page. """
......@@ -1619,7 +1628,7 @@ class CourseListViewTests(TestCase):
for course in self.courses:
factories.CourseUserRoleFactory(course=course, user=self.user, role=InternalUserRole.Publisher)
self.assert_course_list_page(course_count=10, queries_executed=32)
self.assert_course_list_page(course_count=10, queries_executed=33)
def test_courses_with_permission(self):
""" Verify that user can see course with permission on course list page. """
......@@ -1630,7 +1639,7 @@ class CourseListViewTests(TestCase):
course.organizations.add(organization_extension.organization)
assign_perm(OrganizationExtension.VIEW_COURSE, organization_extension.group, organization_extension)
self.assert_course_list_page(course_count=10, queries_executed=64)
self.assert_course_list_page(course_count=10, queries_executed=65)
def assert_course_list_page(self, course_count, queries_executed):
""" Dry method to assert course list page content. """
......@@ -1657,7 +1666,7 @@ class CourseListViewTests(TestCase):
toggle_switch('publisher_hide_features_for_pilot', True)
with self.assertNumQueries(17):
with self.assertNumQueries(18):
response = self.client.get(self.courses_url)
self.assertNotContains(response, 'Edit')
......@@ -1676,13 +1685,13 @@ class CourseListViewTests(TestCase):
toggle_switch('publisher_hide_features_for_pilot', False)
with self.assertNumQueries(21):
with self.assertNumQueries(22):
response = self.client.get(self.courses_url)
self.assertContains(response, 'Edit')
class CourseDetailViewTests(TestCase):
class CourseDetailViewTests(SiteMixin, TestCase):
""" Tests for the course detail view. """
def setUp(self):
......@@ -2114,7 +2123,7 @@ class CourseDetailViewTests(TestCase):
@ddt.ddt
class CourseEditViewTests(TestCase):
class CourseEditViewTests(SiteMixin, TestCase):
""" Tests for the course edit view. """
def setUp(self):
......@@ -2532,7 +2541,7 @@ class CourseEditViewTests(TestCase):
@ddt.ddt
class CourseRunEditViewTests(TestCase):
class CourseRunEditViewTests(SiteMixin, TestCase):
""" Tests for the course run edit view. """
def setUp(self):
......@@ -2550,7 +2559,6 @@ class CourseRunEditViewTests(TestCase):
self.seat = factories.SeatFactory(course_run=self.course_run, type=Seat.VERIFIED, price=2)
self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
current_datetime = datetime.now(timezone('US/Central'))
self.start_date_time = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
......@@ -2833,7 +2841,7 @@ class CourseRunEditViewTests(TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_studio_instance_with_course_team(self):
......@@ -3062,7 +3070,7 @@ class CourseRunEditViewTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CourseRevisionViewTests(TestCase):
class CourseRevisionViewTests(SiteMixin, TestCase):
""" Tests for CourseReview"""
def setUp(self):
......@@ -3114,7 +3122,7 @@ class CourseRevisionViewTests(TestCase):
return self.client.get(path=revision_path)
class CreateRunFromDashboardViewTests(TestCase):
class CreateRunFromDashboardViewTests(SiteMixin, TestCase):
""" Tests for the publisher `CreateRunFromDashboardView`. """
def setUp(self):
......@@ -3214,7 +3222,7 @@ class CreateRunFromDashboardViewTests(TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CreateAdminImportCourseTest(TestCase):
class CreateAdminImportCourseTest(SiteMixin, TestCase):
""" Tests for the publisher `CreateAdminImportCourse`. """
def setUp(self):
......
......@@ -394,14 +394,16 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published:
# If latest run of this course is published send an email to Publisher and don't change state.
send_email_for_published_course_run_editing(latest_run)
send_email_for_published_course_run_editing(latest_run, self.request.site)
else:
user_role = self.object.course_user_roles.get(user=user)
# Change course state to draft if marketing not yet reviewed or
# if marketing person updating the course.
if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer:
if self.object.course_state.name != CourseStateChoices.Draft:
self.object.course_state.change_state(state=CourseStateChoices.Draft, user=user)
self.object.course_state.change_state(
state=CourseStateChoices.Draft, user=user, site=self.request.site
)
# Change ownership if user role not equal to owner role.
if self.object.course_state.owner_role != user_role.role:
......@@ -599,7 +601,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
)
messages.success(request, success_msg)
emails.send_email_for_course_creation(parent_course, course_run)
emails.send_email_for_course_creation(parent_course, course_run, request.site)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as error: # pylint: disable=broad-except
# pylint: disable=no-member
......@@ -740,10 +742,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state = course_run.course_run_state
if course_run_state.name not in immutable_states:
course_run_state.change_state(state=CourseStateChoices.Draft, user=user)
course_run_state.change_state(state=CourseStateChoices.Draft, user=user, site=request.site)
if course_run.lms_course_id and lms_course_id != course_run.lms_course_id:
emails.send_email_for_studio_instance_created(course_run)
emails.send_email_for_studio_instance_created(course_run, site=request.site)
# pylint: disable=no-member
messages.success(request, _('Course run updated successfully.'))
......@@ -757,7 +759,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state.change_owner_role(user_role)
if CourseRunStateChoices.Published == course_run_state.name:
send_email_for_published_course_run_editing(course_run)
send_email_for_published_course_run_editing(course_run, request.site)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as e: # pylint: disable=broad-except
......
......@@ -3,6 +3,7 @@ import json
from django.test import TestCase
from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE
from course_discovery.apps.publisher.tests.factories import CourseRunFactory
......@@ -11,7 +12,7 @@ from course_discovery.apps.publisher_comments.models import Comments
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class PostCommentTests(TestCase):
class PostCommentTests(SiteMixin, TestCase):
def generate_data(self, obj):
"""Generate data for the form."""
......@@ -39,7 +40,7 @@ class PostCommentTests(TestCase):
self.assertEqual(comment.user_email, generated_data['email'])
class UpdateCommentTests(TestCase):
class UpdateCommentTests(SiteMixin, TestCase):
def setUp(self):
super(UpdateCommentTests, self).setUp()
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher_comments.forms import CommentsAdminForm
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class AdminTests(TestCase):
class AdminTests(SiteMixin, TestCase):
""" Tests Admin page and customize form."""
def setUp(self):
super(AdminTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.course = factories.CourseFactory()
self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site)
......
import ddt
import mock
from django.conf import settings
from django.contrib.sites.models import Site
from django.core import mail
from django.test import TestCase
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.publisher.choices import PublisherUserRole
......@@ -20,7 +19,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
class CommentsEmailTests(TestCase):
class CommentsEmailTests(SiteMixin, TestCase):
""" Tests for the e-mail functionality for course, course-run and seats. """
def setUp(self):
......@@ -30,8 +29,6 @@ class CommentsEmailTests(TestCase):
self.user_2 = UserFactory()
self.user_3 = UserFactory()
self.site = Site.objects.get(pk=settings.SITE_ID)
self.organization_extension = factories.OrganizationExtensionFactory()
self.seat = factories.SeatFactory()
......
......@@ -47,6 +47,7 @@ THIRD_PARTY_APPS = [
'django_fsm',
'storages',
'django_comments',
'django_sites_extensions',
'taggit',
'taggit_autosuggest',
'taggit_serializer',
......@@ -79,6 +80,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.contrib.sites.middleware.CurrentSiteMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'social_django.middleware.SocialAuthExceptionMiddleware',
'waffle.middleware.WaffleMiddleware',
......@@ -476,7 +478,9 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
DEFAULT_PARTNER_ID = None
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
# edx-django-sites-extensions will fallback to this site if we cannot identify the site from the hostname.
SITE_ID = 1
COMMENTS_APP = 'course_discovery.apps.publisher_comments'
TAGGIT_CASE_INSENSITIVE = True
......
......@@ -45,3 +45,7 @@ JWT_AUTH['JWT_SECRET_KEY'] = 'course-discovery-jwt-secret-key'
LOGGING['handlers']['local'] = {'class': 'logging.NullHandler'}
PUBLISHER_FROM_EMAIL = 'test@example.com'
# Set to 0 to disable edx-django-sites-extensions to retrieve
# the site from cache and risk working with outdated information.
SITE_CACHE_TTL = 0
......@@ -32,6 +32,7 @@ dry-rest-permissions==0.1.6
edx-auth-backends==1.1.2
edx-ccx-keys==0.2.0
edx-django-release-util==0.3.1
edx-django-sites-extensions==2.3.0
edx-drf-extensions==1.2.3
edx-opaque-keys==0.3.1
edx-rest-api-client==1.6.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment