Commit eb0f4dd7 by Vedran Karacic Committed by Michael Frey

Revert "Filter API data by partner"

This reverts commit bcbd1d4f.
parent 24714f90
import logging
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db.models import QuerySet
from django.utils.translation import ugettext as _
......@@ -12,6 +13,7 @@ from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import NotFound, PermissionDenied
from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program
......@@ -91,7 +93,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = request.site.partner.short_code
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code
return filters
......
......@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags = TagListSerializerField()
@classmethod
def prefetch_queryset(cls, partner):
return Organization.objects.filter(partner=partner).select_related('partner').prefetch_related('tags')
def prefetch_queryset(cls):
return Organization.objects.all().select_related('partner').prefetch_related('tags')
class Meta(MinimalOrganizationSerializer.Meta):
fields = MinimalOrganizationSerializer.Meta.fields + (
......@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
queryset = queryset if queryset is not None else Course.objects.all()
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
class Meta(MinimalCourseSerializer.Meta):
......@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
queryset = queryset if queryset is not None else Course.objects.all()
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
def get_course_runs(self, course):
......@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None, partner=None):
def prefetch_queryset(cls, queryset=None, course_runs=None):
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
"""
queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
queryset = queryset if queryset is not None else Course.objects.all()
return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items',
'prerequisites',
'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
)
def get_course_runs(self, course):
......@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
@classmethod
def prefetch_queryset(cls, partner):
return Program.objects.filter(partner=partner).select_related('type', 'partner').prefetch_related(
def prefetch_queryset(cls):
return Program.objects.all().select_related('type', 'partner').prefetch_related(
'excluded_course_runs',
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
......@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types = serializers.SerializerMethodField()
@classmethod
def prefetch_queryset(cls, partner):
def prefetch_queryset(cls):
"""
Prefetch the related objects that will be serialized with a `Program`.
......@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
"""
return Program.objects.filter(partner=partner).select_related('type', 'video', 'partner').prefetch_related(
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs',
'expected_learning_items',
'faq',
......@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types',
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset()),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
)
......
from django.test import RequestFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
class SiteMixin(object):
def setUp(self):
super(SiteMixin, self).setUp()
domain = 'testserver.fake'
self.client = self.client_class(SERVER_NAME=domain)
self.site = SiteFactory(domain=domain)
self.partner = PartnerFactory(site=self.site)
self.request = RequestFactory(SERVER_NAME=self.site.domain).get('')
self.request.site = self.site
......@@ -21,7 +21,6 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer
)
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
......@@ -97,7 +96,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(SiteMixin, TestCase):
class MinimalCourseSerializerTests(TestCase):
serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request):
......@@ -114,8 +113,8 @@ class MinimalCourseSerializerTests(SiteMixin, TestCase):
def test_data(self):
request = make_request()
organizations = OrganizationFactory(partner=self.partner)
course = CourseFactory(authoring_organizations=[organizations], partner=self.partner)
organizations = OrganizationFactory()
course = CourseFactory(authoring_organizations=[organizations])
CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request)
......@@ -178,10 +177,9 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def setUp(self):
super().setUp()
self.request = make_request()
self.course = CourseFactory(partner=self.partner)
self.course = CourseFactory()
self.deleted_program = ProgramFactory(
courses=[self.course],
partner=self.partner,
status=ProgramStatus.Deleted
)
......
import ddt
import pytest
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import PermissionDenied
from django.test import RequestFactory, TestCase
from django.urls import reverse
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.views import api_docs_permission_denied_handler
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory
class TestApiDocs(APITestCase):
@pytest.mark.django_db
class TestApiDocs:
"""
Regression tests introduced following LEARNER-1590.
"""
path = reverse('api_docs')
def test_api_docs(self):
def test_api_docs(self, admin_client):
"""
Verify that the API docs are available to authenticated clients.
"""
user = UserFactory(is_staff=True)
self.client.login(username=user.username, password=USER_PASSWORD)
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
response = admin_client.get(self.path)
response = self.client.get(self.path)
assert response.status_code == 200
def test_api_docs_redirect(self):
def test_api_docs_redirect(self, client):
"""
Verify that unauthenticated clients are redirected.
"""
response = self.client.get(self.path)
response = client.get(self.path)
assert response.status_code == 302
......
......@@ -4,7 +4,6 @@ import json
import responses
from django.conf import settings
from rest_framework.test import APITestCase as RestAPITestCase
from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.serializers import (
......@@ -12,7 +11,6 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer,
OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer
)
from course_discovery.apps.api.tests.mixins import SiteMixin
class SerializationMixin(object):
......@@ -90,7 +88,3 @@ class OAuth2Mixin(object):
content_type='application/json',
status=status
)
class APITestCase(SiteMixin, RestAPITestCase):
pass
......@@ -7,9 +7,10 @@ import ddt
import pytz
from lxml import etree
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import AffiliateWindowSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
......@@ -45,7 +46,8 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def test_affiliate_with_supported_seats(self):
""" Verify that endpoint returns course runs for verified and professional seats only. """
response = self.client.get(self.affiliate_url)
with self.assertNumQueries(8):
response = self.client.get(self.affiliate_url)
self.assertEqual(response.status_code, 200)
root = ET.fromstring(response.content)
......@@ -128,7 +130,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs
self.client.force_authenticate(superuser)
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
......
......@@ -8,9 +8,10 @@ import pytz
import responses
from django.contrib.auth import get_user_model
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import OAuth2Mixin, SerializationMixin
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory
......@@ -30,7 +31,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
def setUp(self):
super(CatalogViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.force_authenticate(self.user)
self.catalog = CatalogFactory(query='title:abc*')
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30)
......@@ -185,7 +185,8 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run.
assert len(response.data['results'][0]['course_runs']) > 0
else:
response = self.client.get(url)
with self.assertNumQueries(2):
response = self.client.get(url)
assert response.status_code == 200
assert response.data['results'] == []
......@@ -217,7 +218,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(18):
with self.assertNumQueries(17):
response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
......@@ -4,16 +4,19 @@ import urllib
import ddt
import pytz
from django.conf import settings
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory
from rest_framework.test import APIRequestFactory, APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, SeatFactory
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, PartnerFactory, ProgramFactory, SeatFactory
)
@ddt.ddt
......@@ -22,6 +25,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
self.course_run = CourseRunFactory(course__partner=self.partner)
self.course_run_2 = CourseRunFactory(course__partner=self.partner)
self.refresh_index()
......@@ -163,6 +170,15 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def assert_list_results(self, url, expected, extra_context=None):
expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url)
......@@ -252,6 +268,18 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
}
)
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
assert response.status_code == 400
def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({
'query': 'id:course*',
......
......@@ -4,8 +4,9 @@ import ddt
import pytz
from django.db.models.functions import Lower
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course
......@@ -21,15 +22,14 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(CourseViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory(partner=self.partner)
self.course = CourseFactory()
def test_get(self):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(20):
with self.assertNumQueries(18):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course))
......@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(13):
with self.assertNumQueries(11):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), [])
......@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1'
with self.assertNumQueries(24):
with self.assertNumQueries(22):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertEqual(
......@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
with self.assertNumQueries(26):
with self.assertNumQueries(24):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
......@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(39):
with self.assertNumQueries(37):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses = CourseFactory.create_batch(3, partner=self.partner)
courses = CourseFactory.create_batch(3)
courses = sorted(courses, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(39):
with self.assertNumQueries(37):
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import uuid
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import Organization, OrganizationFactory
......@@ -13,7 +14,6 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(OrganizationViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
def test_authentication(self):
......@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_response_data_valid(self, response, organizations, many=True):
""" Asserts the response data (only) contains the expected organizations. """
actual = response.data
serializer_data = self.serialize_organization(organizations, many=many)
if many:
actual = actual['results']
self.assertCountEqual(actual, serializer_data)
self.assertEqual(actual, self.serialize_organization(organizations, many=many))
def assert_list_uuid_filter(self, organizations, expected_query_count):
def assert_list_uuid_filter(self, organizations):
""" Asserts the list endpoint supports filtering by UUID. """
organizations = sorted(organizations, key=lambda o: o.created)
with self.assertNumQueries(expected_query_count):
with self.assertNumQueries(5):
uuids = ','.join([organization.uuid.hex for organization in organizations])
url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids)
response = self.client.get(url)
......@@ -47,6 +47,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_list_tag_filter(self, organizations, tags, expected_query_count=5):
""" Asserts the list endpoint supports filtering by tags. """
with self.assertNumQueries(expected_query_count):
tags = ','.join(tags)
url = '{root}?tags={tags}'.format(root=self.list_path, tags=tags)
......@@ -57,9 +58,10 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list(self):
""" Verify the endpoint returns a list of all organizations. """
OrganizationFactory.create_batch(3, partner=self.partner)
with self.assertNumQueries(7):
OrganizationFactory.create_batch(3)
with self.assertNumQueries(5):
response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200)
......@@ -68,22 +70,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list_uuid_filter(self):
""" Verify the endpoint returns a list of organizations filtered by UUID. """
organizations = OrganizationFactory.create_batch(3, partner=self.partner)
organizations = OrganizationFactory.create_batch(3)
# Test with a single UUID
self.assert_list_uuid_filter([organizations[0]], 7)
self.assert_list_uuid_filter([organizations[0]])
# Test with multiple UUIDs
self.assert_list_uuid_filter(organizations, 5)
self.assert_list_uuid_filter(organizations)
def test_list_tag_filter(self):
""" Verify the endpoint returns a list of organizations filtered by tag. """
tag = 'test-org'
organizations = OrganizationFactory.create_batch(2, partner=self.partner)
organizations = OrganizationFactory.create_batch(2)
# If no organizations have been tagged, the endpoint should not return any data
self.assert_list_tag_filter([], [tag], expected_query_count=5)
self.assert_list_tag_filter([], [tag], expected_query_count=3)
# Tagged organizations should be returned
organizations[0].tags.add(tag)
......@@ -97,7 +99,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns details for a single organization. """
organization = OrganizationFactory(partner=self.partner)
organization = OrganizationFactory()
url = reverse('api:v1:organization-detail', kwargs={'uuid': organization.uuid})
response = self.client.get(url)
......
# pylint: disable=redefined-builtin,no-member
import ddt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import IntegrityError
from mock import mock
......@@ -7,20 +8,20 @@ from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.people import logger as people_logger
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.models import Person
from course_discovery.apps.course_metadata.people import MarketingSitePeople
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory, PositionFactory
from course_discovery.apps.course_metadata.tests.factories import (OrganizationFactory, PartnerFactory, PersonFactory,
PositionFactory)
User = get_user_model()
@ddt.ddt
class PersonViewSetTests(SerializationMixin, SiteMixin, APITestCase):
class PersonViewSetTests(SerializationMixin, APITestCase):
""" Tests for the person resource. """
people_list_url = reverse('api:v1:person-list')
......@@ -31,6 +32,10 @@ class PersonViewSetTests(SerializationMixin, SiteMixin, APITestCase):
self.person = PersonFactory()
PositionFactory(person=self.person)
self.organization = OrganizationFactory()
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
toggle_switch('publish_person_to_marketing_site', True)
self.expected_node = {
'resource': 'node', ''
......
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.models import ProgramType
from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory
......@@ -27,7 +28,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """
ProgramTypeFactory.create_batch(4)
expected = ProgramType.objects.all()
with self.assertNumQueries(6):
with self.assertNumQueries(5):
response = self.client.get(self.list_path)
assert response.status_code == 200
......@@ -38,7 +39,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type = ProgramTypeFactory()
url = reverse('api:v1:program_type-detail', kwargs={'slug': program_type.slug})
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.client.get(url)
assert response.status_code == 200
......
......@@ -3,9 +3,10 @@ import urllib.parse
import ddt
from django.core.cache import cache
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import MinimalProgramSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -24,17 +25,16 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def setUp(self):
super(ProgramViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD)
# Clear the cache between test cases, so they don't interfere with each other.
cache.clear()
def create_program(self):
organizations = [OrganizationFactory(partner=self.partner)]
organizations = [OrganizationFactory()]
person = PersonFactory()
course = CourseFactory(partner=self.partner)
course = CourseFactory()
CourseRunFactory(course=course, staff=[person])
program = ProgramFactory(
......@@ -46,8 +46,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
expected_learning_items=ExpectedLearningItemFactory.create_batch(1),
job_outlook_items=JobOutlookItemFactory.create_batch(1),
banner_image=make_image_file('test_banner.jpg'),
video=VideoFactory(),
partner=self.partner
video=VideoFactory()
)
return program
......@@ -74,7 +73,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self):
""" Verify the endpoint returns the details for a single program. """
program = self.create_program()
with self.assertNumQueries(39):
with self.assertNumQueries(37):
response = self.assert_retrieve_success(program)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
......@@ -91,25 +90,22 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
@ddt.data(True, False)
def test_retrieve_with_sorting_flag(self, order_courses_by_start_date):
""" Verify the number of queries is the same with sorting flag set to true. """
course_list = CourseFactory.create_batch(3, partner=self.partner)
course_list = CourseFactory.create_batch(3)
for course in course_list:
CourseRunFactory(course=course)
program = ProgramFactory(
courses=course_list,
order_courses_by_start_date=order_courses_by_start_date,
partner=self.partner)
program = ProgramFactory(courses=course_list, order_courses_by_start_date=order_courses_by_start_date)
# property does not have the right values while being indexed
del program._course_run_weeks_to_complete
with self.assertNumQueries(28):
with self.assertNumQueries(26):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
self.assertEqual(course_list, list(program.courses.all())) # pylint: disable=no-member
def test_retrieve_without_course_runs(self):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course = CourseFactory(partner=self.partner)
program = ProgramFactory(courses=[course], partner=self.partner)
with self.assertNumQueries(22):
course = CourseFactory()
program = ProgramFactory(courses=[course])
with self.assertNumQueries(20):
response = self.assert_retrieve_success(program)
assert response.data == self.serialize_program(program)
......@@ -139,7 +135,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all programs. """
expected = [self.create_program() for __ in range(3)]
expected.reverse()
self.assert_list_results(self.list_path, expected, 14)
self.assert_list_results(self.list_path, expected, 12)
# Verify that repeated list requests use the cache.
self.assert_list_results(self.list_path, expected, 2)
......@@ -149,8 +145,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
Verify that the list view returns a simply list of UUIDs when the
uuids_only query parameter is passed.
"""
active = ProgramFactory.create_batch(3, partner=self.partner)
retired = [ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)]
active = ProgramFactory.create_batch(3)
retired = [ProgramFactory(status=ProgramStatus.Retired)]
programs = active + retired
querystring = {'uuids_only': 1}
......@@ -169,47 +165,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name, partner=self.partner)
program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program], 10)
self.assert_list_results(url, [program], 8)
url = self.list_path + '?type=bar'
self.assert_list_results(url, [], 3)
def test_filter_by_types(self):
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected = ProgramFactory.create_batch(2)
expected.reverse()
type_slugs = [p.type.slug for p in expected]
url = self.list_path + '?types=' + ','.join(type_slugs)
# Create a third program, which should be filtered out.
ProgramFactory(partner=self.partner)
ProgramFactory()
self.assert_list_results(url, expected, 10)
self.assert_list_results(url, expected, 8)
def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
expected = ProgramFactory.create_batch(2, partner=self.partner)
expected = ProgramFactory.create_batch(2)
expected.reverse()
uuids = [str(p.uuid) for p in expected]
url = self.list_path + '?uuids=' + ','.join(uuids)
# Create a third program, which should be filtered out.
ProgramFactory(partner=self.partner)
ProgramFactory()
self.assert_list_results(url, expected, 10)
self.assert_list_results(url, expected, 8)
@ddt.data(
(ProgramStatus.Unpublished, False, 5),
(ProgramStatus.Active, True, 10),
(ProgramStatus.Unpublished, False, 3),
(ProgramStatus.Active, True, 8),
)
@ddt.unpack
def test_filter_by_marketable(self, status, is_marketable, expected_query_count):
""" Verify the endpoint filters programs to those that are marketable. """
url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='', partner=self.partner)
programs = ProgramFactory.create_batch(3, status=status, partner=self.partner)
ProgramFactory(marketing_slug='')
programs = ProgramFactory.create_batch(3, status=status)
programs.reverse()
expected = programs if is_marketable else []
......@@ -218,11 +214,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_status(self):
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
active = ProgramFactory(status=ProgramStatus.Active, partner=self.partner)
retired = ProgramFactory(status=ProgramStatus.Retired, partner=self.partner)
active = ProgramFactory(status=ProgramStatus.Active)
retired = ProgramFactory(status=ProgramStatus.Retired)
url = self.list_path + '?status=active'
self.assert_list_results(url, [active], 10)
self.assert_list_results(url, [active], 8)
url = self.list_path + '?status=retired'
self.assert_list_results(url, [retired], 8)
......@@ -232,11 +228,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def test_filter_by_hidden(self):
""" Endpoint should filter programs by their hidden attribute value. """
hidden = ProgramFactory(hidden=True, partner=self.partner)
not_hidden = ProgramFactory(hidden=False, partner=self.partner)
hidden = ProgramFactory(hidden=True)
not_hidden = ProgramFactory(hidden=False)
url = self.list_path + '?hidden=True'
self.assert_list_results(url, [hidden], 10)
self.assert_list_results(url, [hidden], 8)
url = self.list_path + '?hidden=False'
self.assert_list_results(url, [not_hidden], 8)
......@@ -251,7 +247,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
url = self.list_path + '?exclude_utm=1'
program = self.create_program()
self.assert_list_results(url, [program], 14, extra_context={'exclude_utm': 1})
self.assert_list_results(url, [program], 12, extra_context={'exclude_utm': 1})
def test_minimal_serializer_use(self):
""" Verify that the list view uses the minimal serializer. """
......
......@@ -3,12 +3,13 @@ import json
import urllib.parse
import ddt
from django.conf import settings
from django.urls import reverse
from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.views.search import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
......@@ -87,8 +88,14 @@ class SynonymTestMixin:
self.assertDictEqual(response1, response2)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt
class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase):
""" Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets')
......@@ -155,9 +162,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
return course_run, response_data
def build_facet_url(self, params):
return 'http://testserver.fake{path}?{query}'.format(
path=self.faceted_path, query=urllib.parse.urlencode(params)
)
return 'http://testserver{path}?{query}'.format(path=self.faceted_path, query=urllib.parse.urlencode(params))
def test_invalid_query_facet(self):
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
......@@ -266,7 +271,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
)
self.reindex_courses(program)
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -290,7 +295,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
ProgramFactory(courses=[course_run.course], status=program_status)
self.reindex_courses(active_program)
with self.assertNumQueries(6):
with self.assertNumQueries(5):
response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200)
......@@ -308,7 +313,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
@ddt.ddt
class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets')
......@@ -433,7 +438,7 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
assert expected == actual
class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead')
......@@ -615,3 +620,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
self.serialize_program(harvard_program)]
}
self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
......@@ -35,3 +35,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
......@@ -9,13 +9,14 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.utils import get_query_param
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member
class CourseRunViewSet(viewsets.ModelViewSet):
class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
""" CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter
......@@ -42,7 +43,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
multiple: false
"""
q = self.request.query_params.get('q')
partner = self.request.site.partner
partner = self.get_partner()
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
......@@ -79,6 +80,12 @@ class CourseRunViewSet(viewsets.ModelViewSet):
type: string
paramType: query
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden
description: Filter based on wether the course run is hidden from search.
required: false
......@@ -159,7 +166,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner = self.request.site.partner
partner = self.get_partner()
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
......
......@@ -18,6 +18,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class = filters.CourseFilter
lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseWithProgramsSerializer
......@@ -26,17 +27,16 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
q = self.request.query_params.get('q')
if q:
queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset, partner=partner)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset)
else:
if get_query_param(self.request, 'include_hidden_course_runs'):
course_runs = CourseRun.objects.filter(course__partner=partner)
course_runs = CourseRun.objects.all()
else:
course_runs = CourseRun.objects.filter(course__partner=partner).exclude(hidden=True)
course_runs = CourseRun.objects.exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active()
......@@ -49,8 +49,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset,
course_runs=course_runs,
partner=partner
course_runs=course_runs
)
return queryset.order_by(Lower('key'))
......
......@@ -15,16 +15,13 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+'
permission_classes = (IsAuthenticated,)
queryset = serializers.OrganizationSerializer.prefetch_queryset()
serializer_class = serializers.OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination.
pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
return serializers.OrganizationSerializer.prefetch_queryset(partner=partner)
def list(self, request, *args, **kwargs):
""" Retrieve a list of all organizations. """
return super(OrganizationViewSet, self).list(request, *args, **kwargs)
......
......@@ -7,6 +7,7 @@ from rest_framework.response import Response
from course_discovery.apps.api import serializers
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.exceptions import MarketingSiteAPIClientException, PersonToMarketingException
from course_discovery.apps.course_metadata.people import MarketingSitePeople
......@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member
class PersonViewSet(viewsets.ModelViewSet):
class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
""" PersonSerializer resource. """
lookup_field = 'uuid'
......@@ -29,7 +30,7 @@ class PersonViewSet(viewsets.ModelViewSet):
""" Create a new person. """
person_data = request.data
partner = request.site.partner
partner = self.get_partner()
person_data['partner'] = partner.id
serializer = self.get_serializer(data=person_data)
serializer.is_valid(raise_exception=True)
......
......@@ -32,8 +32,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
partner = self.request.site.partner
return self.get_serializer_class().prefetch_queryset(partner)
return self.get_serializer_class().prefetch_queryset()
def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs)
......@@ -90,7 +89,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if get_query_param(self.request, 'uuids_only'):
# DRF serializers don't have good support for simple, flat
# representations like the one we want here.
queryset = self.filter_queryset(Program.objects.filter(partner=self.request.site.partner))
queryset = self.filter_queryset(Program.objects.all())
uuids = queryset.values_list('uuid', flat=True)
return Response(uuids)
......
......@@ -12,6 +12,7 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
......@@ -118,7 +119,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(APIView):
class TypeaheadSearchView(PartnerMixin, APIView):
""" Typeahead for courses and programs. """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
......@@ -180,7 +181,7 @@ class TypeaheadSearchView(APIView):
type: string
"""
query = request.query_params.get('q')
partner = request.site.partner
partner = self.get_partner()
if not query:
raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner)
......
......@@ -3,11 +3,10 @@ import json
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
class UserAutocompleteTests(SiteMixin, TestCase):
class UserAutocompleteTests(TestCase):
""" Tests for user autocomplete lookups."""
def setUp(self):
......
......@@ -3,13 +3,12 @@ from django.core.cache import cache
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.throttles import OverridableUserRateThrottle
class RateLimitingTest(SiteMixin, APITestCase):
class RateLimitingTest(APITestCase):
"""
Testing rate limiting of API calls.
"""
......
......@@ -9,20 +9,18 @@ from django.test.utils import override_settings
from django.urls import reverse
from django.utils.encoding import force_text
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.constants import Status
User = get_user_model()
class HealthTests(SiteMixin, TestCase):
class HealthTests(TestCase):
"""Tests of the health endpoint."""
def test_all_services_available(self):
"""Test that the endpoint reports when all services are healthy."""
self._assert_health(200, Status.OK, Status.OK)
@mock.patch('django.contrib.sites.middleware.get_current_site', mock.Mock(return_value=None))
def test_database_outage(self):
"""Test that the endpoint reports when the database is unavailable."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
......@@ -44,7 +42,7 @@ class HealthTests(SiteMixin, TestCase):
self.assertJSONEqual(force_text(response.content), expected_data)
class AutoAuthTests(SiteMixin, TestCase):
class AutoAuthTests(TestCase):
""" Auto Auth view tests. """
AUTO_AUTH_PATH = reverse('auto_auth')
......
......@@ -11,7 +11,6 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import Select
from selenium.webdriver.support.wait import WebDriverWait
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
......@@ -24,7 +23,7 @@ from course_discovery.apps.course_metadata.tests import factories
# pylint: disable=no-member
@ddt.ddt
class AdminTests(SiteMixin, TestCase):
class AdminTests(TestCase):
""" Tests Admin page."""
def setUp(self):
......@@ -191,7 +190,7 @@ class AdminTests(SiteMixin, TestCase):
self.assertEqual(response.status_code, 200)
class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
class ProgramAdminFunctionalTests(LiveServerTestCase):
""" Functional Tests for Admin page."""
# Required for access to initial data loaded in migrations (e.g., LanguageTags).
serialized_rollback = True
......@@ -225,6 +224,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
def setUp(self):
super().setUp()
# ContentTypeManager uses a cache to speed up ContentType retrieval. This
# cache persists across tests. This is fine in the context of a regular
# TestCase which uses a transaction to reset the database between tests.
......@@ -238,9 +238,6 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
# stale ContentType objects from being used.
ContentType.objects.clear_cache()
self.site.domain = self.live_server_url.strip('http://')
self.site.save()
self.course_runs = factories.CourseRunFactory.create_batch(2)
self.courses = [course_run.course for course_run in self.course_runs]
......@@ -352,7 +349,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
self.assertEqual(self.program.subtitle, subtitle)
class ProgramEligibilityFilterTests(SiteMixin, TestCase):
class ProgramEligibilityFilterTests(TestCase):
""" Tests for Program Eligibility Filter class. """
parameter_name = 'eligible_for_one_click_purchase'
......
......@@ -5,7 +5,6 @@ import ddt
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory
......@@ -17,7 +16,7 @@ from course_discovery.apps.publisher.tests import factories
@ddt.ddt
class AutocompleteTests(SiteMixin, TestCase):
class AutocompleteTests(TestCase):
""" Tests for autocomplete lookups."""
def setUp(self):
super(AutocompleteTests, self).setUp()
......@@ -119,7 +118,7 @@ class AutocompleteTests(SiteMixin, TestCase):
@ddt.ddt
class AutoCompletePersonTests(SiteMixin, TestCase):
class AutoCompletePersonTests(TestCase):
"""
Tests for person autocomplete lookups
"""
......
......@@ -2,17 +2,17 @@ import datetime
import urllib.parse
from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.tests.test_views.test_search import (
ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
DefaultPartnerMixin, ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
)
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, ProgramFactory
from course_discovery.apps.edx_catalog_extensions.api.serializers import DistinctCountsAggregateFacetSearchSerializer
class DistinctCountsAggregateSearchViewSetTests(SerializationMixin, LoginMixin,
class DistinctCountsAggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin,
ElasticsearchTestMixin, SynonymTestMixin, APITestCase):
path = reverse('extensions:api:v1:search-all-facets')
......
......@@ -4,14 +4,13 @@ import ddt
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member
@ddt.ddt
class AutocompleteTests(SiteMixin, TestCase):
class AutocompleteTests(TestCase):
""" Tests for autocomplete lookups."""
def setUp(self):
super(AutocompleteTests, self).setUp()
......
......@@ -40,8 +40,7 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
former_user = instance.user
instance = super(CourseUserRoleSerializer, self).update(instance, validated_data)
if not instance.role == PublisherUserRole.CourseTeam:
request = self.context['request']
send_change_role_assignment_email(instance, former_user, request.site)
send_change_role_assignment_email(instance, former_user)
return instance
......@@ -105,7 +104,6 @@ class CourseRunSerializer(serializers.ModelSerializer):
instance = super(CourseRunSerializer, self).update(instance, validated_data)
preview_url = validated_data.get('preview_url')
lms_course_id = validated_data.get('lms_course_id')
request = self.context['request']
if preview_url:
# Change ownership to CourseTeam.
......@@ -113,10 +111,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
if waffle.switch_is_active('enable_publisher_email_notifications'):
if preview_url:
send_email_preview_page_is_available(instance, site=request.site)
send_email_preview_page_is_available(instance)
elif lms_course_id:
send_email_for_studio_instance_created(instance, site=request.site)
send_email_for_studio_instance_created(instance)
return instance
......@@ -169,7 +167,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
state = validated_data.get('name')
request = self.context.get('request')
try:
instance.change_state(state=state, user=request.user, site=request.site)
instance.change_state(state=state, user=request.user)
except TransitionNotAllowed:
# pylint: disable=no-member
raise serializers.ValidationError(
......@@ -206,7 +204,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
if state:
try:
instance.change_state(state=state, user=request.user, site=request.site)
instance.change_state(state=state, user=request.user)
except TransitionNotAllowed:
# pylint: disable=no-member
raise serializers.ValidationError(
......@@ -225,6 +223,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
instance.save()
if waffle.switch_is_active('enable_publisher_email_notifications'):
send_email_preview_accepted(instance.course_run, request.site)
send_email_preview_accepted(instance.course_run)
return instance
......@@ -5,7 +5,6 @@ from django.test import RequestFactory, TestCase
from opaque_keys.edx.keys import CourseKey
from rest_framework.exceptions import ValidationError
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -21,7 +20,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
OrganizationExtensionFactory, SeatFactory)
class CourseUserRoleSerializerTests(SiteMixin, TestCase):
class CourseUserRoleSerializerTests(TestCase):
serializer_class = CourseUserRoleSerializer
def setUp(self):
......@@ -29,7 +28,6 @@ class CourseUserRoleSerializerTests(SiteMixin, TestCase):
self.request = RequestFactory()
self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer)
self.request.user = self.course_user_role.user
self.request.site = self.site
def get_expected_data(self):
""" Helper method which will return expected serialize data. """
......@@ -140,7 +138,7 @@ class CourseRunSerializerTests(TestCase):
"""
self.course_run.preview_url = ''
self.course_run.save()
serializer = self.serializer_class(self.course_run, context={'request': self.request})
serializer = self.serializer_class(self.course_run)
serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'})
self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam)
......@@ -248,12 +246,13 @@ class CourseRevisionSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected)
class CourseStateSerializerTests(SiteMixin, TestCase):
class CourseStateSerializerTests(TestCase):
serializer_class = CourseStateSerializer
def setUp(self):
super(CourseStateSerializerTests, self).setUp()
self.course_state = CourseStateFactory(name=CourseStateChoices.Draft)
self.request = RequestFactory()
self.user = UserFactory()
self.request.user = self.user
......@@ -290,13 +289,14 @@ class CourseStateSerializerTests(SiteMixin, TestCase):
serializer.update(self.course_state, data)
class CourseRunStateSerializerTests(SiteMixin, TestCase):
class CourseRunStateSerializerTests(TestCase):
serializer_class = CourseRunStateSerializer
def setUp(self):
super(CourseRunStateSerializerTests, self).setUp()
self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft)
self.course_run = self.run_state.course_run
self.request = RequestFactory()
self.user = UserFactory()
self.request.user = self.user
CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course)
......
......@@ -4,6 +4,7 @@ from urllib.parse import quote
import ddt
from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail
from django.db import IntegrityError
from django.test import TestCase
......@@ -13,7 +14,6 @@ from mock import mock, patch
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
@ddt.ddt
class CourseRoleAssignmentViewTests(SiteMixin, TestCase):
class CourseRoleAssignmentViewTests(TestCase):
def setUp(self):
super(CourseRoleAssignmentViewTests, self).setUp()
......@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(SiteMixin, TestCase):
self.assertEqual(len(mail.outbox), 1)
class OrganizationGroupUserViewTests(SiteMixin, TestCase):
class OrganizationGroupUserViewTests(TestCase):
def setUp(self):
super(OrganizationGroupUserViewTests, self).setUp()
......@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(SiteMixin, TestCase):
)
class UpdateCourseRunViewTests(SiteMixin, TestCase):
class UpdateCourseRunViewTests(TestCase):
def setUp(self):
super(UpdateCourseRunViewTests, self).setUp()
......@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_update_preview_url(self):
......@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase):
self.assertEqual(len(mail.outbox), 0)
class CourseRevisionDetailViewTests(SiteMixin, TestCase):
class CourseRevisionDetailViewTests(TestCase):
def setUp(self):
super(CourseRevisionDetailViewTests, self).setUp()
......@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(SiteMixin, TestCase):
return self.client.get(path=course_revision_path)
class ChangeCourseStateViewTests(SiteMixin, TestCase):
class ChangeCourseStateViewTests(TestCase):
def setUp(self):
super(ChangeCourseStateViewTests, self).setUp()
......@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip()
object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_change_course_state_with_error(self):
......@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase):
self._assert_email_sent(course_team_user, subject)
class ChangeCourseRunStateViewTests(SiteMixin, TestCase):
class ChangeCourseRunStateViewTests(TestCase):
def setUp(self):
super(ChangeCourseRunStateViewTests, self).setUp()
......@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(SiteMixin, TestCase):
self.assertIn('has been published', mail.outbox[0].body.strip())
class RevertCourseByRevisionTests(SiteMixin, TestCase):
class RevertCourseByRevisionTests(TestCase):
def setUp(self):
super(RevertCourseByRevisionTests, self).setUp()
......@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(SiteMixin, TestCase):
return self.client.put(path=course_revision_path)
class CoursesAutoCompleteTests(SiteMixin, TestCase):
class CoursesAutoCompleteTests(TestCase):
""" Tests for course autocomplete."""
def setUp(self):
......@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(SiteMixin, TestCase):
self.assertEqual(len(data['results']), expected_length)
class AcceptAllByRevisionTests(SiteMixin, TestCase):
class AcceptAllByRevisionTests(TestCase):
def setUp(self):
super(AcceptAllByRevisionTests, self).setUp()
......
import logging
from django.conf import settings
from django.contrib.sites.models import Site
from django.core.mail.message import EmailMultiAlternatives
from django.template.loader import get_template
from django.urls import reverse
......@@ -15,12 +16,11 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled
logger = logging.getLogger(__name__)
def send_email_for_studio_instance_created(course_run, site):
def send_email_for_studio_instance_created(course_run):
""" Send an email to course team on studio instance creation.
Arguments:
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
try:
course_key = CourseKey.from_string(course_run.lms_course_id)
......@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run, site):
context = {
'course_run': course_run,
'course_run_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=object_path
host=Site.objects.get_current().domain.strip('/'), path=object_path
),
'course_name': course_run.course.title,
'from_address': from_address,
......@@ -65,13 +65,12 @@ def send_email_for_studio_instance_created(course_run, site):
raise Exception(error_message)
def send_email_for_course_creation(course, course_run, site):
def send_email_for_course_creation(course, course_run):
""" Send the emails for a course creation.
Arguments:
course (Course): Course object
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_created.txt'
html_template = 'publisher/email/course_created.html'
......@@ -92,7 +91,7 @@ def send_email_for_course_creation(course, course_run, site):
'course_team_name': course_team.get_full_name(),
'project_coordinator_name': project_coordinator.get_full_name(),
'dashboard_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=reverse('publisher:publisher_dashboard')
host=Site.objects.get_current().domain.strip('/'), path=reverse('publisher:publisher_dashboard')
),
'from_address': from_address,
'contact_us_email': project_coordinator.email
......@@ -114,13 +113,12 @@ def send_email_for_course_creation(course, course_run, site):
)
def send_email_for_send_for_review(course, user, site):
def send_email_for_send_for_review(course, user):
""" Send email when course is submitted for review.
Arguments:
course (Object): Course object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course/send_for_review.txt'
html_template = 'publisher/email/course/send_for_review.html'
......@@ -137,22 +135,21 @@ def send_email_for_send_for_review(course, user, site):
'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
)
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course %s', course.id)
def send_email_for_mark_as_reviewed(course, user, site):
def send_email_for_mark_as_reviewed(course, user):
""" Send email when course is marked as reviewed.
Arguments:
course (Object): Course object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course/mark_as_reviewed.txt'
html_template = 'publisher/email/course/mark_as_reviewed.html'
......@@ -169,16 +166,16 @@ def send_email_for_mark_as_reviewed(course, user, site):
'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
)
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications mark as reviewed of course %s', course.id)
def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site):
def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user):
""" Send email for course workflow state change.
Arguments:
......@@ -189,7 +186,6 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
html_template: (String): Email html template path
context: (Dict): Email template context
recipient_user: (Object): User object
site (Site): Current site
"""
if is_email_notification_enabled(recipient_user):
......@@ -206,7 +202,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
)
}
)
......@@ -223,13 +219,12 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
email_msg.send()
def send_email_for_send_for_review_course_run(course_run, user, site):
def send_email_for_send_for_review_course_run(course_run, user):
""" Send email when course-run is submitted for review.
Arguments:
course-run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
course = course_run.course
course_key = CourseKey.from_string(course_run.lms_course_id)
......@@ -251,23 +246,22 @@ def send_email_for_send_for_review_course_run(course_run, user, site):
'run_number': course_key.run,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'project coordinators',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
),
'studio_url': course_run.studio_url
}
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course-run %s', course_run.id)
def send_email_for_mark_as_reviewed_course_run(course_run, user, site):
def send_email_for_mark_as_reviewed_course_run(course_run, user):
""" Send email when course-run is marked as reviewed.
Arguments:
course_run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html'
......@@ -290,24 +284,21 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user, site):
'run_number': course_key.run,
'sender_team': 'course team',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
)
}
send_course_workflow_email(
course, user, subject, txt_template, html_template, context, recipient_user, site
)
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_to_publisher(course_run, user, site):
def send_email_to_publisher(course_run, user):
""" Send email to publisher when course-run is marked as reviewed.
Arguments:
course_run (Object): CourseRun object
user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html'
......@@ -339,7 +330,7 @@ def send_email_to_publisher(course_run, user, site):
'sender_team': sender_team,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
)
}
......@@ -358,12 +349,11 @@ def send_email_to_publisher(course_run, user, site):
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_preview_accepted(course_run, site):
def send_email_preview_accepted(course_run):
""" Send email for preview approved to publisher and project coordinator.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/preview_accepted.txt'
html_template = 'publisher/email/course_run/preview_accepted.html'
......@@ -392,10 +382,10 @@ def send_email_preview_accepted(course_run, site):
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
)
}
template = get_template(txt_template)
......@@ -416,12 +406,11 @@ def send_email_preview_accepted(course_run, site):
raise Exception(message)
def send_email_preview_page_is_available(course_run, site):
def send_email_preview_page_is_available(course_run):
""" Send email for course preview available to course team.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/preview_available.txt'
html_template = 'publisher/email/course_run/preview_available.html'
......@@ -447,10 +436,10 @@ def send_email_preview_page_is_available(course_run, site):
'preview_link': course_run.preview_url,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
),
'platform_name': settings.PLATFORM_NAME
}
......@@ -473,12 +462,11 @@ def send_email_preview_page_is_available(course_run, site):
raise Exception(error_message)
def send_course_run_published_email(course_run, site):
def send_course_run_published_email(course_run):
""" Send email when course run is published by publisher.
Arguments:
course_run (Object): CourseRun object
site (Site): Current site
"""
txt_template = 'publisher/email/course_run/published.txt'
html_template = 'publisher/email/course_run/published.html'
......@@ -504,10 +492,10 @@ def send_course_run_published_email(course_run, site):
'recipient_name': course_team_user.get_full_name() or course_team_user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
),
'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
),
'platform_name': settings.PLATFORM_NAME,
}
......@@ -530,13 +518,12 @@ def send_course_run_published_email(course_run, site):
raise Exception(error_message)
def send_change_role_assignment_email(course_role, former_user, site):
def send_change_role_assignment_email(course_role, former_user):
""" Send email for role assignment changed.
Arguments:
course_role (Object): CourseUserRole object
former_user (Object): User object
site (Site): Current site
"""
txt_template = 'publisher/email/role_assignment_changed.txt'
html_template = 'publisher/email/role_assignment_changed.html'
......@@ -562,7 +549,7 @@ def send_change_role_assignment_email(course_role, former_user, site):
'current_user_name': course_role.user.get_full_name() or course_role.user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path
host=Site.objects.get_current().domain.strip('/'), path=page_path
),
'platform_name': settings.PLATFORM_NAME,
}
......@@ -585,12 +572,11 @@ def send_change_role_assignment_email(course_role, former_user, site):
raise Exception(error_message)
def send_email_for_seo_review(course, site):
def send_email_for_seo_review(course):
""" Send email when course is submitted for seo review.
Arguments:
course (Object): Course object
site (Site): Current site
"""
txt_template = 'publisher/email/course/seo_review.txt'
html_template = 'publisher/email/course/seo_review.html'
......@@ -611,7 +597,7 @@ def send_email_for_seo_review(course, site):
'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email,
'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path
host=Site.objects.get_current().domain.strip('/'), path=course_page_path
)
}
......@@ -629,12 +615,11 @@ def send_email_for_seo_review(course, site):
logger.exception('Failed to send email notifications for legal review requested of course %s', course.id)
def send_email_for_published_course_run_editing(course_run, site):
def send_email_for_published_course_run_editing(course_run):
""" Send email when published course-run is edited.
Arguments:
course-run (Object): Course Run object
site (Site): Current site
"""
try:
course = course_run.course
......@@ -659,7 +644,7 @@ def send_email_for_published_course_run_editing(course_run, site):
'recipient_name': publisher_user.get_full_name() or publisher_user.username,
'contact_us_email': course.project_coordinator.email,
'course_run_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=object_path
host=Site.objects.get_current().domain.strip('/'), path=object_path
),
'course_run_number': course_key.run,
}
......
......@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
# TODO: send email etc.
pass
def change_state(self, state, user, site=None):
def change_state(self, state, user):
"""
Change course workflow state and ownership also send emails if required.
"""
......@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
elif user_role.role == PublisherUserRole.CourseTeam:
self.change_owner_role(PublisherUserRole.MarketingReviewer)
if is_notifications_enabled:
emails.send_email_for_seo_review(self.course, site)
emails.send_email_for_seo_review(self.course)
self.review()
if is_notifications_enabled:
emails.send_email_for_send_for_review(self.course, user, site)
emails.send_email_for_send_for_review(self.course, user)
elif state == CourseStateChoices.Approved:
user_role = self.course.course_user_roles.get(user=user)
......@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
self.approved()
if is_notifications_enabled:
emails.send_email_for_mark_as_reviewed(self.course, user, site)
emails.send_email_for_mark_as_reviewed(self.course, user)
self.save()
......@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
pass
@transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published)
def published(self, site):
emails.send_course_run_published_email(self.course_run, site)
def published(self):
emails.send_course_run_published_email(self.course_run)
def change_state(self, state, user, site=None):
def change_state(self, state, user):
"""
Change course run workflow state and ownership also send emails if required.
"""
......@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.review()
if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_send_for_review_course_run(self.course_run, user, site)
emails.send_email_for_send_for_review_course_run(self.course_run, user)
elif state == CourseRunStateChoices.Approved:
user_role = self.course_run.course.course_user_roles.get(user=user)
......@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.approved()
if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user, site)
emails.send_email_to_publisher(self.course_run, user, site)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user)
emails.send_email_to_publisher(self.course_run, user)
elif state == CourseRunStateChoices.Published:
self.published(site)
self.published()
self.save()
......
......@@ -4,7 +4,6 @@ from django.test import TestCase
from django.urls import reverse
from guardian.shortcuts import get_group_perms
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory
from course_discovery.apps.publisher.choices import PublisherUserRole
......@@ -19,7 +18,7 @@ USER_PASSWORD = 'password'
# pylint: disable=no-member
class AdminTests(SiteMixin, TestCase):
class AdminTests(TestCase):
""" Tests Admin page."""
def setUp(self):
......@@ -82,7 +81,7 @@ class AdminTests(SiteMixin, TestCase):
self.assertEqual(response.status_code, 200)
class OrganizationExtensionAdminTests(SiteMixin, TestCase):
class OrganizationExtensionAdminTests(TestCase):
""" Tests for OrganizationExtensionAdmin."""
def setUp(self):
......@@ -135,7 +134,7 @@ class OrganizationExtensionAdminTests(SiteMixin, TestCase):
@ddt.ddt
class OrganizationUserRoleAdminTests(SiteMixin, TestCase):
class OrganizationUserRoleAdminTests(TestCase):
""" Tests for OrganizationUserRoleAdmin."""
def setUp(self):
......
......@@ -2,13 +2,13 @@
import mock
from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail
from django.test import TestCase
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch
......@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher.tests.factories import UserAttributeFactory
class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
class StudioInstanceCreatedEmailTests(TestCase):
"""
Tests for the studio instance created email functionality.
"""
......@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
""" Verify that emails failure raise exception."""
with self.assertRaises(Exception) as ex:
emails.send_email_for_studio_instance_created(self.course_run, self.site)
emails.send_email_for_studio_instance_created(self.course_run)
error_message = 'Failed to send email notifications for course_run [{}]'.format(self.course_run.id)
self.assertEqual(ex.message, error_message)
def test_email_sent_successfully(self):
""" Verify that emails sent successfully for studio instance created."""
emails.send_email_for_studio_instance_created(self.course_run, self.site)
emails.send_email_for_studio_instance_created(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
self.assert_email_sent(
reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}),
......@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
self.assertIn('Enter course run content in Studio.', body)
self.assertIn('Thanks', body)
......@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
)
class CourseCreatedEmailTests(SiteMixin, TestCase):
class CourseCreatedEmailTests(TestCase):
""" Tests for the new course created email functionality. """
def setUp(self):
......@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
""" Verify that emails failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
l.check(
(
emails.logger.name,
......@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
def test_email_sent_successfully(self):
""" Verify that studio instance request email sent successfully."""
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
subject = 'Studio URL requested: {title}'.format(title=self.course_run.course.title)
self.assert_email_sent(subject)
......@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
user_attribute = UserAttributes.objects.get(user=self.user)
user_attribute.enable_email_notification = False
user_attribute.save()
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site)
emails.send_email_for_course_creation(self.course_run.course, self.course_run)
self.assertEqual(len(mail.outbox), 0)
class SendForReviewEmailTests(SiteMixin, TestCase):
class SendForReviewEmailTests(TestCase):
""" Tests for the send for review email functionality. """
def setUp(self):
......@@ -168,7 +168,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review(self.course_state.course, self.user, self.site)
emails.send_email_for_send_for_review(self.course_state.course, self.user)
l.check(
(
emails.logger.name,
......@@ -180,7 +180,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase):
)
class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
class CourseMarkAsReviewedEmailTests(TestCase):
""" Tests for the mark as reviewed email functionality. """
def setUp(self):
......@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user, self.site)
emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user)
l.check(
(
emails.logger.name,
......@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
)
class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
class CourseRunSendForReviewEmailTests(TestCase):
""" Tests for the CourseRun send for review email functionality. """
def setUp(self):
......@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user, self.site)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user_2)
......@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2, self.site)
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user)
......@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review_course_run(self.course_run, self.user, self.site)
emails.send_email_for_send_for_review_course_run(self.course_run, self.user)
l.check(
(
emails.logger.name,
......@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('View this course run in Publisher to review the changes or suggest edits.', body)
class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
class CourseRunMarkAsReviewedEmailTests(TestCase):
""" Tests for the CourseRun mark as reviewed email functionality. """
def setUp(self):
......@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user, self.site)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user)
self.assertEqual(len(mail.outbox), 0)
def test_email_sent_by_course_team(self):
......@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2, self.site)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2)
self.assert_email_sent(self.user)
def test_email_mark_as_reviewed_with_error(self):
""" Verify that email failure log error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user, self.site)
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user)
l.check(
(
emails.logger.name,
......@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
)
emails.send_email_to_publisher(self.course_run_state.course_run, self.user, self.site)
emails.send_email_to_publisher(self.course_run_state.course_run, self.user)
self.assert_email_sent(self.user_3)
def test_email_to_publisher_with_error(self):
......@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with LogCapture(emails.logger.name) as l:
emails.send_email_to_publisher(self.course_run, self.user_3, self.site)
emails.send_email_to_publisher(self.course_run, self.user_3)
l.check(
(
emails.logger.name,
......@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('The review for this course run is complete.', body)
class CourseRunPreviewEmailTests(SiteMixin, TestCase):
class CourseRunPreviewEmailTests(TestCase):
"""
Tests for the course preview email functionality.
"""
......@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
lms_course_id = 'course-v1:edX+DemoX+Demo_Course'
self.run_state.course_run.lms_course_id = lms_course_id
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
emails.send_email_preview_accepted(self.run_state.course_run)
course_key = CourseKey.from_string(lms_course_id)
subject = 'Publication requested: {course_name} {run_number}'.format(
......@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.run_state.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('You can now publish this About page.', body)
......@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
with self.assertRaises(Exception) as ex:
self.assertEqual(str(ex.exception), message)
with LogCapture(emails.logger.name) as l:
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
emails.send_email_preview_accepted(self.run_state.course_run)
l.check(
(
emails.logger.name,
......@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
course_run.lms_course_id = 'course-v1:testX+testX1.0+2017T1'
course_run.save()
emails.send_email_preview_page_is_available(course_run, self.site)
emails.send_email_preview_page_is_available(course_run)
course_key = CourseKey.from_string(course_run.lms_course_id)
subject = 'Review requested: Preview for {course_name} {run_number}'.format(
......@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('A preview is now available for the', body)
......@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
""" Verify that exception raised on email failure."""
with self.assertRaises(Exception) as ex:
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site)
emails.send_email_preview_page_is_available(self.run_state.course_run)
error_message = 'Failed to send email notifications for preview available of course-run {}'.format(
self.run_state.course_run.id
)
......@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
def test_preview_available_email_with_notification_disabled(self):
""" Verify that email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.course_team_admin, enable_email_notification=False)
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site)
emails.send_email_preview_page_is_available(self.run_state.course_run)
self.assertEqual(len(mail.outbox), 0)
def test_preview_accepted_email_with_notification_disabled(self):
""" Verify that preview accepted email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.publisher, enable_email_notification=False)
emails.send_email_preview_accepted(self.run_state.course_run, self.site)
emails.send_email_preview_accepted(self.run_state.course_run)
self.assertEqual(len(mail.outbox), 0)
class CourseRunPublishedEmailTests(SiteMixin, TestCase):
class CourseRunPublishedEmailTests(TestCase):
"""
Tests for course run published email functionality.
"""
......@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase):
"""
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save()
emails.send_course_run_published_email(self.course_run, self.site)
emails.send_course_run_published_email(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
subject = 'Publication complete: About page for {course_name} {run_number}'.format(
......@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase):
)
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex:
emails.send_course_run_published_email(self.course_run, self.site)
emails.send_course_run_published_email(self.course_run)
self.assertEqual(str(ex.exception), message)
class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
class CourseChangeRoleAssignmentEmailTests(TestCase):
"""
Tests email functionality for course role assignment changed.
"""
......@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
"""
Verify that course role assignment chnage email functionality works fine.
"""
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site)
emails.send_change_role_assignment_email(self.marketing_role, self.user)
expected_subject = '{role_name} changed for {course_title}'.format(
role_name=self.marketing_role.get_role_display().lower(),
course_title=self.course.title
......@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('has changed.', body)
......@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
)
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex:
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site)
emails.send_change_role_assignment_email(self.marketing_role, self.user)
self.assertEqual(str(ex.exception), message)
class SEOReviewEmailTests(SiteMixin, TestCase):
class SEOReviewEmailTests(TestCase):
""" Tests for the seo review email functionality. """
def setUp(self):
......@@ -626,7 +626,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_seo_review(self.course, self.site)
emails.send_email_for_seo_review(self.course)
l.check(
(
emails.logger.name,
......@@ -642,7 +642,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
Verify that seo review email functionality works fine.
"""
factories.CourseUserRoleFactory(course=self.course, role=PublisherUserRole.ProjectCoordinator)
emails.send_email_for_seo_review(self.course, self.site)
emails.send_email_for_seo_review(self.course)
expected_subject = 'Legal review requested: {title}'.format(title=self.course.title)
self.assertEqual(len(mail.outbox), 1)
......@@ -652,7 +652,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body)
self.assertIn('determine OFAC status', body)
......@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
)
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save()
emails.send_email_for_published_course_run_editing(self.course_run, self.site)
emails.send_email_for_published_course_run_editing(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id)
......@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
""" Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l:
emails.send_email_for_published_course_run_editing(self.course_run, self.site)
emails.send_email_for_published_course_run_editing(self.course_run)
l.check(
(
emails.logger.name,
......
......@@ -6,7 +6,6 @@ from django.urls import reverse
from django_fsm import TransitionNotAllowed
from guardian.shortcuts import assign_perm
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory
......@@ -511,7 +510,7 @@ class GroupOrganizationTests(TestCase):
@ddt.ddt
class CourseStateTests(SiteMixin, TestCase):
class CourseStateTests(TestCase):
""" Tests for the publisher `CourseState` model. """
@classmethod
......@@ -549,7 +548,7 @@ class CourseStateTests(SiteMixin, TestCase):
"""
self.assertNotEqual(self.course_state.name, state)
self.course_state.change_state(state=state, user=self.user, site=self.site)
self.course_state.change_state(state=state, user=self.user)
self.assertEqual(self.course_state.name, state)
......@@ -562,7 +561,7 @@ class CourseStateTests(SiteMixin, TestCase):
self.assertEqual(self.course_state.name, CourseStateChoices.Draft)
with self.assertRaises(TransitionNotAllowed):
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user, site=self.site)
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user)
def test_can_send_for_review(self):
"""
......@@ -645,7 +644,7 @@ class CourseStateTests(SiteMixin, TestCase):
@ddt.ddt
class CourseRunStateTests(SiteMixin, TestCase):
class CourseRunStateTests(TestCase):
""" Tests for the publisher `CourseRunState` model. """
@classmethod
......@@ -704,7 +703,7 @@ class CourseRunStateTests(SiteMixin, TestCase):
Verify that we can change course-run state according to workflow.
"""
self.assertNotEqual(self.course_run_state.name, state)
self.course_run_state.change_state(state=state, user=self.user, site=self.site)
self.course_run_state.change_state(state=state, user=self.user)
self.assertEqual(self.course_run_state.name, state)
def test_with_invalid_parent_course_state(self):
......
......@@ -19,13 +19,11 @@ from opaque_keys.edx.keys import CourseKey
from pytz import timezone
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import (CourseFactory, OrganizationFactory, PersonFactory,
SubjectFactory)
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, OrganizationFactory, PersonFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag
from course_discovery.apps.publisher.choices import (CourseRunStateChoices, CourseStateChoices, InternalUserRole,
PublisherUserRole)
......@@ -44,7 +42,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
class CreateCourseViewTests(SiteMixin, TestCase):
class CreateCourseViewTests(TestCase):
""" Tests for the publisher `CreateCourseView`. """
def setUp(self):
......@@ -63,6 +61,7 @@ class CreateCourseViewTests(SiteMixin, TestCase):
self.course = factories.CourseFactory()
self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
# creating default organizations roles
......@@ -270,7 +269,7 @@ class CreateCourseViewTests(SiteMixin, TestCase):
)
class CreateCourseRunViewTests(SiteMixin, TestCase):
class CreateCourseRunViewTests(TestCase):
""" Tests for the publisher `UpdateCourseRunView`. """
def setUp(self):
......@@ -300,6 +299,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase):
current_datetime = datetime.now(timezone('US/Central'))
self.course_run_dict['start'] = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
self.course_run_dict['end'] = (current_datetime + timedelta(days=3)).strftime('%Y-%m-%d %H:%M:%S')
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
def _pop_valuse_from_dict(self, data_dict, key_list):
......@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase):
@ddt.ddt
class CourseRunDetailTests(SiteMixin, TestCase):
class CourseRunDetailTests(TestCase):
""" Tests for the course-run detail view. """
def setUp(self):
......@@ -763,8 +763,9 @@ class CourseRunDetailTests(SiteMixin, TestCase):
"""
self.client.logout()
self.client.login(username=self.user.username, password=USER_PASSWORD)
site = Site.objects.get(pk=settings.SITE_ID)
comment = CommentFactory(content_object=self.course_run, user=self.user, site=self.site)
comment = CommentFactory(content_object=self.course_run, user=self.user, site=site)
response = self.client.get(self.page_url)
self.assertEqual(response.status_code, 200)
self._assert_credits_seats(response, self.wrapped_course_run.credit_seat)
......@@ -778,7 +779,7 @@ class CourseRunDetailTests(SiteMixin, TestCase):
# test decline comment appearing on detail page also.
decline_comment = CommentFactory(
content_object=self.course_run,
user=self.user, site=self.site, comment_type=CommentTypeChoices.Decline_Preview
user=self.user, site=site, comment_type=CommentTypeChoices.Decline_Preview
)
response = self.client.get(self.page_url)
self.assertContains(response, decline_comment.comment)
......@@ -1232,12 +1233,12 @@ class CourseRunDetailTests(SiteMixin, TestCase):
# pylint: disable=attribute-defined-outside-init
@ddt.ddt
class DashboardTests(SiteMixin, TestCase):
class DashboardTests(TestCase):
""" Tests for the `Dashboard`. """
def setUp(self):
super(DashboardTests, self).setUp()
Site.objects.exclude(id=self.site.id).delete()
self.group_internal = Group.objects.get(name=INTERNAL_USER_GROUP_NAME)
self.group_project_coordinator = Group.objects.get(name=PROJECT_COORDINATOR_GROUP_NAME)
self.group_reviewer = Group.objects.get(name=REVIEWER_GROUP_NAME)
......@@ -1276,18 +1277,7 @@ class DashboardTests(SiteMixin, TestCase):
def _create_course_assign_role(self, state, user, role):
""" Create course-run-state, course-user-role and return course-run. """
course = factories.CourseFactory(
primary_subject=SubjectFactory(partner=self.partner),
secondary_subject=SubjectFactory(partner=self.partner),
tertiary_subject=SubjectFactory(partner=self.partner)
)
course_run = factories.CourseRunFactory(course=course)
course_run_state = factories.CourseRunStateFactory(
name=state,
owner_role=role,
course_run=course_run
)
course_run_state = factories.CourseRunStateFactory(name=state, owner_role=role)
factories.CourseUserRoleFactory(course=course_run_state.course_run.course, role=role, user=user)
return course_run_state.course_run
......@@ -1311,7 +1301,7 @@ class DashboardTests(SiteMixin, TestCase):
self.client.logout()
self.client.login(username=UserFactory(), password=USER_PASSWORD)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=12
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=11
)
self._assert_tabs_with_roles(response)
......@@ -1319,7 +1309,7 @@ class DashboardTests(SiteMixin, TestCase):
def test_with_internal_group(self, tab):
""" Verify that internal user can see courses assigned to the groups. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
)
self.assertContains(response, '<li role="tab" id="tab-{tab}" class="tab"'.format(tab=tab))
......@@ -1334,7 +1324,7 @@ class DashboardTests(SiteMixin, TestCase):
self.course_run_1.course.organizations.add(self.organization_extension.organization)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=12
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=11
)
self._assert_tabs_with_roles(response)
......@@ -1359,14 +1349,14 @@ class DashboardTests(SiteMixin, TestCase):
)
response = self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=2, preview_count=1, queries_executed=22
studio_count=0, published_count=0, progress_count=2, preview_count=1, queries_executed=21
)
self._assert_tabs_with_roles(response)
def test_studio_request_course_runs_as_pc(self):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
)
self._assert_tabs_with_roles(response)
......@@ -1374,7 +1364,7 @@ class DashboardTests(SiteMixin, TestCase):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
self.user1.groups.remove(self.group_project_coordinator)
response = self.assert_dashboard_response(
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=21
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=20
)
self._assert_tabs_with_roles(response)
......@@ -1385,7 +1375,7 @@ class DashboardTests(SiteMixin, TestCase):
self.course_run_2.lms_course_id = 'test-2'
self.course_run_2.save()
response = self.assert_dashboard_response(
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=22
studio_count=0, published_count=1, progress_count=2, preview_count=1, queries_executed=21
)
self.assertContains(response, 'No courses are currently ready for a Studio URL.')
......@@ -1394,7 +1384,7 @@ class DashboardTests(SiteMixin, TestCase):
self.course_run_3.course_run_state.name = CourseRunStateChoices.Draft
self.course_run_3.course_run_state.save()
response = self.assert_dashboard_response(
studio_count=3, published_count=0, progress_count=3, preview_count=1, queries_executed=25
studio_count=3, published_count=0, progress_count=3, preview_count=1, queries_executed=24
)
self.assertContains(response, 'No About pages have been published yet')
self._assert_tabs_with_roles(response)
......@@ -1402,7 +1392,7 @@ class DashboardTests(SiteMixin, TestCase):
def test_published_course_runs(self):
""" Verify that published tab loads course runs list. """
response = self.assert_dashboard_response(
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=24
studio_count=2, published_count=1, progress_count=2, preview_count=1, queries_executed=23
)
self.assertContains(response, self.table_class.format(id='published'))
self.assertContains(response, 'About pages for the following course runs have been published in the')
......@@ -1420,7 +1410,7 @@ class DashboardTests(SiteMixin, TestCase):
# Verify that user cannot see any published course run
self.assert_dashboard_response(
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=16
studio_count=0, published_count=0, progress_count=0, preview_count=0, queries_executed=15
)
# assign user course role
......@@ -1444,14 +1434,14 @@ class DashboardTests(SiteMixin, TestCase):
publisher_admin.groups.add(Group.objects.get(name=ADMIN_GROUP_NAME))
self.client.login(username=publisher_admin.username, password=USER_PASSWORD)
response = self.assert_dashboard_response(
studio_count=4, published_count=1, progress_count=3, preview_count=1, queries_executed=21
studio_count=4, published_count=1, progress_count=3, preview_count=1, queries_executed=20
)
self._assert_tabs_with_roles(response)
def test_with_preview_ready_course_runs(self):
""" Verify that preview ready tabs loads the course runs list. """
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
)
self.assertContains(response, self.table_class.format(id='preview'))
self.assertContains(response, 'About page previews for the following course runs are available for course team')
......@@ -1463,7 +1453,7 @@ class DashboardTests(SiteMixin, TestCase):
self.course_run_2.course_run_state.name = CourseRunStateChoices.Draft
self.course_run_2.course_run_state.save()
response = self.assert_dashboard_response(
studio_count=2, preview_count=0, progress_count=3, published_count=1, queries_executed=23
studio_count=2, preview_count=0, progress_count=3, published_count=1, queries_executed=22
)
self._assert_tabs_with_roles(response)
......@@ -1472,7 +1462,7 @@ class DashboardTests(SiteMixin, TestCase):
preview url is added or not.
"""
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
)
self._assert_tabs_with_roles(response)
......@@ -1487,7 +1477,7 @@ class DashboardTests(SiteMixin, TestCase):
def test_with_in_progress_course_runs(self):
""" Verify that in progress tabs loads the course runs list. """
response = self.assert_dashboard_response(
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=24
studio_count=2, preview_count=1, progress_count=2, published_count=1, queries_executed=23
)
self.assertContains(response, self.table_class.format(id='in-progress'))
self._assert_tabs_with_roles(response)
......@@ -1523,7 +1513,7 @@ class DashboardTests(SiteMixin, TestCase):
self.client.logout()
self.client.login(username=pc_user.username, password=USER_PASSWORD)
with self.assertNumQueries(12):
with self.assertNumQueries(11):
response = self.client.get(self.page_url)
for tab in ['progress', 'preview', 'studio', 'published']:
......@@ -1533,7 +1523,7 @@ class DashboardTests(SiteMixin, TestCase):
"""
Verify that site_name is available in context.
"""
with self.assertNumQueries(24):
with self.assertNumQueries(23):
response = self.client.get(self.page_url)
site = Site.objects.first()
self.assertEqual(response.context['site_name'], site.name)
......@@ -1552,12 +1542,13 @@ class DashboardTests(SiteMixin, TestCase):
course_run.course_run_state.owner_role = PublisherUserRole.CourseTeam
course_run.course_run_state.save()
with self.assertNumQueries(26):
with self.assertNumQueries(25):
response = self.client.get(self.page_url)
site = Site.objects.first()
self._assert_filter_counts(response, 'All', 3)
self._assert_filter_counts(response, 'With Course Team', 2)
self._assert_filter_counts(response, 'With {site_name}'.format(site_name=self.site.name), 1)
self._assert_filter_counts(response, 'With {site_name}'.format(site_name=site.name), 1)
def _assert_filter_counts(self, response, expected_label, count):
"""
......@@ -1568,7 +1559,7 @@ class DashboardTests(SiteMixin, TestCase):
self.assertContains(response, expected_count, count=1)
class ToggleEmailNotificationTests(SiteMixin, TestCase):
class ToggleEmailNotificationTests(TestCase):
""" Tests for `ToggleEmailNotification` view. """
def setUp(self):
......@@ -1601,7 +1592,7 @@ class ToggleEmailNotificationTests(SiteMixin, TestCase):
self.assertEqual(is_email_notification_enabled(user), is_enabled)
class CourseListViewTests(SiteMixin, TestCase):
class CourseListViewTests(TestCase):
""" Tests for `CourseListView` """
def setUp(self):
......@@ -1615,12 +1606,12 @@ class CourseListViewTests(SiteMixin, TestCase):
def test_courses_with_no_courses(self):
""" Verify that user cannot see any course on course list page. """
self.assert_course_list_page(course_count=0, queries_executed=9)
self.assert_course_list_page(course_count=0, queries_executed=8)
def test_courses_with_admin(self):
""" Verify that admin user can see all courses on course list page. """
self.user.groups.add(Group.objects.get(name=ADMIN_GROUP_NAME))
self.assert_course_list_page(course_count=10, queries_executed=32)
self.assert_course_list_page(course_count=10, queries_executed=31)
def test_courses_with_course_user_role(self):
""" Verify that internal user can see course on course list page. """
......@@ -1628,7 +1619,7 @@ class CourseListViewTests(SiteMixin, TestCase):
for course in self.courses:
factories.CourseUserRoleFactory(course=course, user=self.user, role=InternalUserRole.Publisher)
self.assert_course_list_page(course_count=10, queries_executed=33)
self.assert_course_list_page(course_count=10, queries_executed=32)
def test_courses_with_permission(self):
""" Verify that user can see course with permission on course list page. """
......@@ -1639,7 +1630,7 @@ class CourseListViewTests(SiteMixin, TestCase):
course.organizations.add(organization_extension.organization)
assign_perm(OrganizationExtension.VIEW_COURSE, organization_extension.group, organization_extension)
self.assert_course_list_page(course_count=10, queries_executed=65)
self.assert_course_list_page(course_count=10, queries_executed=64)
def assert_course_list_page(self, course_count, queries_executed):
""" Dry method to assert course list page content. """
......@@ -1685,13 +1676,13 @@ class CourseListViewTests(SiteMixin, TestCase):
toggle_switch('publisher_hide_features_for_pilot', False)
with self.assertNumQueries(22):
with self.assertNumQueries(21):
response = self.client.get(self.courses_url)
self.assertContains(response, 'Edit')
class CourseDetailViewTests(SiteMixin, TestCase):
class CourseDetailViewTests(TestCase):
""" Tests for the course detail view. """
def setUp(self):
......@@ -2123,7 +2114,7 @@ class CourseDetailViewTests(SiteMixin, TestCase):
@ddt.ddt
class CourseEditViewTests(SiteMixin, TestCase):
class CourseEditViewTests(TestCase):
""" Tests for the course edit view. """
def setUp(self):
......@@ -2541,7 +2532,7 @@ class CourseEditViewTests(SiteMixin, TestCase):
@ddt.ddt
class CourseRunEditViewTests(SiteMixin, TestCase):
class CourseRunEditViewTests(TestCase):
""" Tests for the course run edit view. """
def setUp(self):
......@@ -2559,6 +2550,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
self.seat = factories.SeatFactory(course_run=self.course_run, type=Seat.VERIFIED, price=2)
self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD)
current_datetime = datetime.now(timezone('US/Central'))
self.start_date_time = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
......@@ -2841,7 +2833,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body)
def test_studio_instance_with_course_team(self):
......@@ -3070,7 +3062,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CourseRevisionViewTests(SiteMixin, TestCase):
class CourseRevisionViewTests(TestCase):
""" Tests for CourseReview"""
def setUp(self):
......@@ -3122,7 +3114,7 @@ class CourseRevisionViewTests(SiteMixin, TestCase):
return self.client.get(path=revision_path)
class CreateRunFromDashboardViewTests(SiteMixin, TestCase):
class CreateRunFromDashboardViewTests(TestCase):
""" Tests for the publisher `CreateRunFromDashboardView`. """
def setUp(self):
......@@ -3222,7 +3214,7 @@ class CreateRunFromDashboardViewTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CreateAdminImportCourseTest(SiteMixin, TestCase):
class CreateAdminImportCourseTest(TestCase):
""" Tests for the publisher `CreateAdminImportCourse`. """
def setUp(self):
......
......@@ -394,16 +394,14 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published:
# If latest run of this course is published send an email to Publisher and don't change state.
send_email_for_published_course_run_editing(latest_run, self.request.site)
send_email_for_published_course_run_editing(latest_run)
else:
user_role = self.object.course_user_roles.get(user=user)
# Change course state to draft if marketing not yet reviewed or
# if marketing person updating the course.
if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer:
if self.object.course_state.name != CourseStateChoices.Draft:
self.object.course_state.change_state(
state=CourseStateChoices.Draft, user=user, site=self.request.site
)
self.object.course_state.change_state(state=CourseStateChoices.Draft, user=user)
# Change ownership if user role not equal to owner role.
if self.object.course_state.owner_role != user_role.role:
......@@ -601,7 +599,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
)
messages.success(request, success_msg)
emails.send_email_for_course_creation(parent_course, course_run, request.site)
emails.send_email_for_course_creation(parent_course, course_run)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as error: # pylint: disable=broad-except
# pylint: disable=no-member
......@@ -742,10 +740,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state = course_run.course_run_state
if course_run_state.name not in immutable_states:
course_run_state.change_state(state=CourseStateChoices.Draft, user=user, site=request.site)
course_run_state.change_state(state=CourseStateChoices.Draft, user=user)
if course_run.lms_course_id and lms_course_id != course_run.lms_course_id:
emails.send_email_for_studio_instance_created(course_run, site=request.site)
emails.send_email_for_studio_instance_created(course_run)
# pylint: disable=no-member
messages.success(request, _('Course run updated successfully.'))
......@@ -759,7 +757,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state.change_owner_role(user_role)
if CourseRunStateChoices.Published == course_run_state.name:
send_email_for_published_course_run_editing(course_run, request.site)
send_email_for_published_course_run_editing(course_run)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as e: # pylint: disable=broad-except
......
......@@ -3,7 +3,6 @@ import json
from django.test import TestCase
from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE
from course_discovery.apps.publisher.tests.factories import CourseRunFactory
......@@ -12,7 +11,7 @@ from course_discovery.apps.publisher_comments.models import Comments
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class PostCommentTests(SiteMixin, TestCase):
class PostCommentTests(TestCase):
def generate_data(self, obj):
"""Generate data for the form."""
......@@ -40,7 +39,7 @@ class PostCommentTests(SiteMixin, TestCase):
self.assertEqual(comment.user_email, generated_data['email'])
class UpdateCommentTests(SiteMixin, TestCase):
class UpdateCommentTests(TestCase):
def setUp(self):
super(UpdateCommentTests, self).setUp()
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import TestCase
from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher_comments.forms import CommentsAdminForm
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class AdminTests(SiteMixin, TestCase):
class AdminTests(TestCase):
""" Tests Admin page and customize form."""
def setUp(self):
super(AdminTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.login(username=self.user.username, password=USER_PASSWORD)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.course = factories.CourseFactory()
self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site)
......
import ddt
import mock
from django.conf import settings
from django.contrib.sites.models import Site
from django.core import mail
from django.test import TestCase
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.publisher.choices import PublisherUserRole
......@@ -19,7 +20,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
class CommentsEmailTests(SiteMixin, TestCase):
class CommentsEmailTests(TestCase):
""" Tests for the e-mail functionality for course, course-run and seats. """
def setUp(self):
......@@ -29,6 +30,8 @@ class CommentsEmailTests(SiteMixin, TestCase):
self.user_2 = UserFactory()
self.user_3 = UserFactory()
self.site = Site.objects.get(pk=settings.SITE_ID)
self.organization_extension = factories.OrganizationExtensionFactory()
self.seat = factories.SeatFactory()
......
......@@ -79,7 +79,6 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.contrib.sites.middleware.CurrentSiteMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'social_django.middleware.SocialAuthExceptionMiddleware',
'waffle.middleware.WaffleMiddleware',
......@@ -476,6 +475,8 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
DEFAULT_PARTNER_ID = None
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
SITE_ID = 1
COMMENTS_APP = 'course_discovery.apps.publisher_comments'
TAGGIT_CASE_INSENSITIVE = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment