Commit fc642742 by Vedran Karacic Committed by Vedran Karačić

Revert "Remove SITE_ID from settings."

This reverts commit a804e243.
parent a804e243
from django.test import RequestFactory from django.conf import settings
from django.contrib.sites.models import Site
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
class SiteMixin(object): class PartnerMixin(object):
def setUp(self): def setUp(self):
super(SiteMixin, self).setUp() super(PartnerMixin, self).setUp()
domain = 'testserver.fake' Site.objects.all().delete()
self.client = self.client_class(SERVER_NAME=domain) self.site = SiteFactory(id=settings.SITE_ID)
self.site = SiteFactory(domain=domain)
self.partner = PartnerFactory(site=self.site) self.partner = PartnerFactory(site=self.site)
self.request = RequestFactory(SERVER_NAME=self.site.domain).get('')
self.request.site = self.site
...@@ -21,7 +21,7 @@ from course_discovery.apps.api.serializers import ( ...@@ -21,7 +21,7 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer, ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer TypeaheadProgramSearchSerializer, VideoSerializer
) )
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import PartnerMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
...@@ -97,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase): ...@@ -97,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(SiteMixin, TestCase): class MinimalCourseSerializerTests(PartnerMixin, TestCase):
serializer_class = MinimalCourseSerializer serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request): def get_expected_data(self, course, request):
......
...@@ -12,7 +12,7 @@ from course_discovery.apps.api.serializers import ( ...@@ -12,7 +12,7 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer, CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer,
OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer
) )
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.tests.mixins import PartnerMixin
class SerializationMixin(object): class SerializationMixin(object):
...@@ -92,5 +92,5 @@ class OAuth2Mixin(object): ...@@ -92,5 +92,5 @@ class OAuth2Mixin(object):
) )
class APITestCase(SiteMixin, RestAPITestCase): class APITestCase(PartnerMixin, RestAPITestCase):
pass pass
...@@ -7,9 +7,10 @@ import ddt ...@@ -7,9 +7,10 @@ import ddt
import pytz import pytz
from lxml import etree from lxml import etree
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import AffiliateWindowSerializer from course_discovery.apps.api.serializers import AffiliateWindowSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
...@@ -45,7 +46,8 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP ...@@ -45,7 +46,8 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def test_affiliate_with_supported_seats(self): def test_affiliate_with_supported_seats(self):
""" Verify that endpoint returns course runs for verified and professional seats only. """ """ Verify that endpoint returns course runs for verified and professional seats only. """
response = self.client.get(self.affiliate_url) with self.assertNumQueries(9):
response = self.client.get(self.affiliate_url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
root = ET.fromstring(response.content) root = ET.fromstring(response.content)
......
...@@ -8,9 +8,10 @@ import pytz ...@@ -8,9 +8,10 @@ import pytz
import responses import responses
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import OAuth2Mixin, SerializationMixin
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
...@@ -30,7 +31,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -30,7 +31,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
def setUp(self): def setUp(self):
super(CatalogViewSetTests, self).setUp() super(CatalogViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
self.catalog = CatalogFactory(query='title:abc*') self.catalog = CatalogFactory(query='title:abc*')
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30) enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30)
......
...@@ -21,7 +21,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -21,7 +21,6 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
def setUp(self): def setUp(self):
super(CourseViewSetTests, self).setUp() super(CourseViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory(partner=self.partner) self.course = CourseFactory(partner=self.partner)
......
...@@ -13,7 +13,6 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -13,7 +13,6 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def setUp(self): def setUp(self):
super(OrganizationViewSetTests, self).setUp() super(OrganizationViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
def test_authentication(self): def test_authentication(self):
...@@ -29,14 +28,16 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -29,14 +28,16 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
""" Asserts the response data (only) contains the expected organizations. """ """ Asserts the response data (only) contains the expected organizations. """
actual = response.data actual = response.data
serializer_data = self.serialize_organization(organizations, many=many) serializer_data = self.serialize_organization(organizations, many=many)
if many: if many:
actual = actual['results'] actual = actual['results']
actual = sorted(actual, key=lambda k: k['uuid'])
serializer_data = sorted(serializer_data, key=lambda k: k['uuid'])
self.assertCountEqual(actual, serializer_data) self.assertEqual(actual, serializer_data)
def assert_list_uuid_filter(self, organizations, expected_query_count): def assert_list_uuid_filter(self, organizations, expected_query_count):
""" Asserts the list endpoint supports filtering by UUID. """ """ Asserts the list endpoint supports filtering by UUID. """
organizations = sorted(organizations, key=lambda o: o.created)
with self.assertNumQueries(expected_query_count): with self.assertNumQueries(expected_query_count):
uuids = ','.join([organization.uuid.hex for organization in organizations]) uuids = ','.join([organization.uuid.hex for organization in organizations])
url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids) url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids)
......
...@@ -7,8 +7,7 @@ from rest_framework.reverse import reverse ...@@ -7,8 +7,7 @@ from rest_framework.reverse import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.api.v1.tests.test_views.mixins import PartnerMixin, SerializationMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.people import logger as people_logger from course_discovery.apps.api.v1.views.people import logger as people_logger
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.models import Person from course_discovery.apps.course_metadata.models import Person
...@@ -20,7 +19,7 @@ User = get_user_model() ...@@ -20,7 +19,7 @@ User = get_user_model()
@ddt.ddt @ddt.ddt
class PersonViewSetTests(SerializationMixin, SiteMixin, APITestCase): class PersonViewSetTests(SerializationMixin, PartnerMixin, APITestCase):
""" Tests for the person resource. """ """ Tests for the person resource. """
people_list_url = reverse('api:v1:person-list') people_list_url = reverse('api:v1:person-list')
......
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.models import ProgramType from course_discovery.apps.course_metadata.models import ProgramType
from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory
......
...@@ -24,7 +24,6 @@ class ProgramViewSetTests(SerializationMixin, APITestCase): ...@@ -24,7 +24,6 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def setUp(self): def setUp(self):
super(ProgramViewSetTests, self).setUp() super(ProgramViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
# Clear the cache between test cases, so they don't interfere with each other. # Clear the cache between test cases, so they don't interfere with each other.
......
...@@ -155,9 +155,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT ...@@ -155,9 +155,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
return course_run, response_data return course_run, response_data
def build_facet_url(self, params): def build_facet_url(self, params):
return 'http://testserver.fake{path}?{query}'.format( return 'http://testserver{path}?{query}'.format(path=self.faceted_path, query=urllib.parse.urlencode(params))
path=self.faceted_path, query=urllib.parse.urlencode(params)
)
def test_invalid_query_facet(self): def test_invalid_query_facet(self):
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """ """ Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
......
...@@ -3,11 +3,10 @@ import json ...@@ -3,11 +3,10 @@ import json
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
class UserAutocompleteTests(SiteMixin, TestCase): class UserAutocompleteTests(TestCase):
""" Tests for user autocomplete lookups.""" """ Tests for user autocomplete lookups."""
def setUp(self): def setUp(self):
......
...@@ -2,13 +2,12 @@ from django.core.cache import cache ...@@ -2,13 +2,12 @@ from django.core.cache import cache
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import UserThrottleRate from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.throttles import OverridableUserRateThrottle from course_discovery.apps.core.throttles import OverridableUserRateThrottle
class RateLimitingTest(SiteMixin, APITestCase): class RateLimitingTest(APITestCase):
""" """
Testing rate limiting of API calls. Testing rate limiting of API calls.
""" """
......
...@@ -9,23 +9,33 @@ from django.test.utils import override_settings ...@@ -9,23 +9,33 @@ from django.test.utils import override_settings
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_text from django.utils.encoding import force_text
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.constants import Status from course_discovery.apps.core.constants import Status
from course_discovery.apps.core.views import get_database_status
User = get_user_model() User = get_user_model()
class HealthTests(SiteMixin, TestCase): class HealthTests(TestCase):
"""Tests of the health endpoint.""" """Tests of the health endpoint."""
def test_getting_database_ok_status(self):
"""Method should return the OK status."""
status = get_database_status()
self.assertEqual(status, Status.OK)
def test_getting_database_unavailable_status(self):
"""Method should return the unavailable status when a DatabaseError occurs."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
status = get_database_status()
self.assertEqual(status, Status.UNAVAILABLE)
def test_all_services_available(self): def test_all_services_available(self):
"""Test that the endpoint reports when all services are healthy.""" """Test that the endpoint reports when all services are healthy."""
self._assert_health(200, Status.OK, Status.OK) self._assert_health(200, Status.OK, Status.OK)
@mock.patch('django.contrib.sites.middleware.get_current_site', mock.Mock(return_value=None))
def test_database_outage(self): def test_database_outage(self):
"""Test that the endpoint reports when the database is unavailable.""" """Test that the endpoint reports when the database is unavailable."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError): with mock.patch('course_discovery.apps.core.views.get_database_status', return_value=Status.UNAVAILABLE):
self._assert_health(503, Status.UNAVAILABLE, Status.UNAVAILABLE) self._assert_health(503, Status.UNAVAILABLE, Status.UNAVAILABLE)
def _assert_health(self, status_code, overall_status, database_status): def _assert_health(self, status_code, overall_status, database_status):
...@@ -44,7 +54,7 @@ class HealthTests(SiteMixin, TestCase): ...@@ -44,7 +54,7 @@ class HealthTests(SiteMixin, TestCase):
self.assertJSONEqual(force_text(response.content), expected_data) self.assertJSONEqual(force_text(response.content), expected_data)
class AutoAuthTests(SiteMixin, TestCase): class AutoAuthTests(TestCase):
""" Auto Auth view tests. """ """ Auto Auth view tests. """
AUTO_AUTH_PATH = reverse('auto_auth') AUTO_AUTH_PATH = reverse('auto_auth')
......
...@@ -15,6 +15,18 @@ logger = logging.getLogger(__name__) ...@@ -15,6 +15,18 @@ logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
def get_database_status():
"""Run a database query to see if the database is responsive."""
try:
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
return Status.OK
except DatabaseError:
return Status.UNAVAILABLE
@transaction.non_atomic_requests @transaction.non_atomic_requests
def health(_): def health(_):
"""Allows a load balancer to verify this service is up. """Allows a load balancer to verify this service is up.
...@@ -32,15 +44,7 @@ def health(_): ...@@ -32,15 +44,7 @@ def health(_):
>>> response.content >>> response.content
'{"overall_status": "OK", "detailed_status": {"database_status": "OK"}}' '{"overall_status": "OK", "detailed_status": {"database_status": "OK"}}'
""" """
database_status = get_database_status()
try:
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
database_status = Status.OK
except DatabaseError:
database_status = Status.UNAVAILABLE
overall_status = Status.OK if (database_status == Status.OK) else Status.UNAVAILABLE overall_status = Status.OK if (database_status == Status.OK) else Status.UNAVAILABLE
......
...@@ -11,7 +11,6 @@ from selenium.webdriver.support import expected_conditions as EC ...@@ -11,7 +11,6 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import Select from selenium.webdriver.support.ui import Select
from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support.wait import WebDriverWait
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Partner from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
...@@ -24,7 +23,7 @@ from course_discovery.apps.course_metadata.tests import factories ...@@ -24,7 +23,7 @@ from course_discovery.apps.course_metadata.tests import factories
# pylint: disable=no-member # pylint: disable=no-member
@ddt.ddt @ddt.ddt
class AdminTests(SiteMixin, TestCase): class AdminTests(TestCase):
""" Tests Admin page.""" """ Tests Admin page."""
def setUp(self): def setUp(self):
...@@ -191,7 +190,7 @@ class AdminTests(SiteMixin, TestCase): ...@@ -191,7 +190,7 @@ class AdminTests(SiteMixin, TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase): class ProgramAdminFunctionalTests(LiveServerTestCase):
""" Functional Tests for Admin page.""" """ Functional Tests for Admin page."""
# Required for access to initial data loaded in migrations (e.g., LanguageTags). # Required for access to initial data loaded in migrations (e.g., LanguageTags).
serialized_rollback = True serialized_rollback = True
...@@ -225,6 +224,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase): ...@@ -225,6 +224,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
# ContentTypeManager uses a cache to speed up ContentType retrieval. This # ContentTypeManager uses a cache to speed up ContentType retrieval. This
# cache persists across tests. This is fine in the context of a regular # cache persists across tests. This is fine in the context of a regular
# TestCase which uses a transaction to reset the database between tests. # TestCase which uses a transaction to reset the database between tests.
...@@ -238,9 +238,6 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase): ...@@ -238,9 +238,6 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
# stale ContentType objects from being used. # stale ContentType objects from being used.
ContentType.objects.clear_cache() ContentType.objects.clear_cache()
self.site.domain = self.live_server_url.strip('http://')
self.site.save()
self.course_runs = factories.CourseRunFactory.create_batch(2) self.course_runs = factories.CourseRunFactory.create_batch(2)
self.courses = [course_run.course for course_run in self.course_runs] self.courses = [course_run.course for course_run in self.course_runs]
...@@ -357,7 +354,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase): ...@@ -357,7 +354,7 @@ class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
self.assertEqual(self.program.subtitle, subtitle) self.assertEqual(self.program.subtitle, subtitle)
class ProgramEligibilityFilterTests(SiteMixin, TestCase): class ProgramEligibilityFilterTests(TestCase):
""" Tests for Program Eligibility Filter class. """ """ Tests for Program Eligibility Filter class. """
parameter_name = 'eligible_for_one_click_purchase' parameter_name = 'eligible_for_one_click_purchase'
......
...@@ -5,7 +5,6 @@ import ddt ...@@ -5,7 +5,6 @@ import ddt
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import ( from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory
...@@ -17,7 +16,7 @@ from course_discovery.apps.publisher.tests import factories ...@@ -17,7 +16,7 @@ from course_discovery.apps.publisher.tests import factories
@ddt.ddt @ddt.ddt
class AutocompleteTests(SiteMixin, TestCase): class AutocompleteTests(TestCase):
""" Tests for autocomplete lookups.""" """ Tests for autocomplete lookups."""
def setUp(self): def setUp(self):
super(AutocompleteTests, self).setUp() super(AutocompleteTests, self).setUp()
...@@ -119,7 +118,7 @@ class AutocompleteTests(SiteMixin, TestCase): ...@@ -119,7 +118,7 @@ class AutocompleteTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class AutoCompletePersonTests(SiteMixin, TestCase): class AutoCompletePersonTests(TestCase):
""" """
Tests for person autocomplete lookups Tests for person autocomplete lookups
""" """
......
...@@ -4,14 +4,13 @@ import ddt ...@@ -4,14 +4,13 @@ import ddt
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member # pylint: disable=no-member
@ddt.ddt @ddt.ddt
class AutocompleteTests(SiteMixin, TestCase): class AutocompleteTests(TestCase):
""" Tests for autocomplete lookups.""" """ Tests for autocomplete lookups."""
def setUp(self): def setUp(self):
super(AutocompleteTests, self).setUp() super(AutocompleteTests, self).setUp()
......
...@@ -40,8 +40,7 @@ class CourseUserRoleSerializer(serializers.ModelSerializer): ...@@ -40,8 +40,7 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
former_user = instance.user former_user = instance.user
instance = super(CourseUserRoleSerializer, self).update(instance, validated_data) instance = super(CourseUserRoleSerializer, self).update(instance, validated_data)
if not instance.role == PublisherUserRole.CourseTeam: if not instance.role == PublisherUserRole.CourseTeam:
request = self.context['request'] send_change_role_assignment_email(instance, former_user)
send_change_role_assignment_email(instance, former_user, request.site)
return instance return instance
...@@ -105,7 +104,6 @@ class CourseRunSerializer(serializers.ModelSerializer): ...@@ -105,7 +104,6 @@ class CourseRunSerializer(serializers.ModelSerializer):
instance = super(CourseRunSerializer, self).update(instance, validated_data) instance = super(CourseRunSerializer, self).update(instance, validated_data)
preview_url = validated_data.get('preview_url') preview_url = validated_data.get('preview_url')
lms_course_id = validated_data.get('lms_course_id') lms_course_id = validated_data.get('lms_course_id')
request = self.context['request']
if preview_url: if preview_url:
# Change ownership to CourseTeam. # Change ownership to CourseTeam.
...@@ -113,10 +111,10 @@ class CourseRunSerializer(serializers.ModelSerializer): ...@@ -113,10 +111,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
if preview_url: if preview_url:
send_email_preview_page_is_available(instance, site=request.site) send_email_preview_page_is_available(instance)
elif lms_course_id: elif lms_course_id:
send_email_for_studio_instance_created(instance, site=request.site) send_email_for_studio_instance_created(instance)
return instance return instance
...@@ -169,7 +167,7 @@ class CourseStateSerializer(serializers.ModelSerializer): ...@@ -169,7 +167,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
state = validated_data.get('name') state = validated_data.get('name')
request = self.context.get('request') request = self.context.get('request')
try: try:
instance.change_state(state=state, user=request.user, site=request.site) instance.change_state(state=state, user=request.user)
except TransitionNotAllowed: except TransitionNotAllowed:
# pylint: disable=no-member # pylint: disable=no-member
raise serializers.ValidationError( raise serializers.ValidationError(
...@@ -206,7 +204,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer): ...@@ -206,7 +204,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
if state: if state:
try: try:
instance.change_state(state=state, user=request.user, site=request.site) instance.change_state(state=state, user=request.user)
except TransitionNotAllowed: except TransitionNotAllowed:
# pylint: disable=no-member # pylint: disable=no-member
raise serializers.ValidationError( raise serializers.ValidationError(
...@@ -225,6 +223,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer): ...@@ -225,6 +223,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
instance.save() instance.save()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
send_email_preview_accepted(instance.course_run, request.site) send_email_preview_accepted(instance.course_run)
return instance return instance
...@@ -5,7 +5,6 @@ from django.test import RequestFactory, TestCase ...@@ -5,7 +5,6 @@ from django.test import RequestFactory, TestCase
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
...@@ -21,7 +20,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour ...@@ -21,7 +20,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
OrganizationExtensionFactory, SeatFactory) OrganizationExtensionFactory, SeatFactory)
class CourseUserRoleSerializerTests(SiteMixin, TestCase): class CourseUserRoleSerializerTests(TestCase):
serializer_class = CourseUserRoleSerializer serializer_class = CourseUserRoleSerializer
def setUp(self): def setUp(self):
...@@ -29,7 +28,6 @@ class CourseUserRoleSerializerTests(SiteMixin, TestCase): ...@@ -29,7 +28,6 @@ class CourseUserRoleSerializerTests(SiteMixin, TestCase):
self.request = RequestFactory() self.request = RequestFactory()
self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer) self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer)
self.request.user = self.course_user_role.user self.request.user = self.course_user_role.user
self.request.site = self.site
def get_expected_data(self): def get_expected_data(self):
""" Helper method which will return expected serialize data. """ """ Helper method which will return expected serialize data. """
...@@ -140,7 +138,7 @@ class CourseRunSerializerTests(TestCase): ...@@ -140,7 +138,7 @@ class CourseRunSerializerTests(TestCase):
""" """
self.course_run.preview_url = '' self.course_run.preview_url = ''
self.course_run.save() self.course_run.save()
serializer = self.serializer_class(self.course_run, context={'request': self.request}) serializer = self.serializer_class(self.course_run)
serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'}) serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'})
self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam) self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam)
...@@ -248,12 +246,13 @@ class CourseRevisionSerializerTests(TestCase): ...@@ -248,12 +246,13 @@ class CourseRevisionSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
class CourseStateSerializerTests(SiteMixin, TestCase): class CourseStateSerializerTests(TestCase):
serializer_class = CourseStateSerializer serializer_class = CourseStateSerializer
def setUp(self): def setUp(self):
super(CourseStateSerializerTests, self).setUp() super(CourseStateSerializerTests, self).setUp()
self.course_state = CourseStateFactory(name=CourseStateChoices.Draft) self.course_state = CourseStateFactory(name=CourseStateChoices.Draft)
self.request = RequestFactory()
self.user = UserFactory() self.user = UserFactory()
self.request.user = self.user self.request.user = self.user
...@@ -290,13 +289,14 @@ class CourseStateSerializerTests(SiteMixin, TestCase): ...@@ -290,13 +289,14 @@ class CourseStateSerializerTests(SiteMixin, TestCase):
serializer.update(self.course_state, data) serializer.update(self.course_state, data)
class CourseRunStateSerializerTests(SiteMixin, TestCase): class CourseRunStateSerializerTests(TestCase):
serializer_class = CourseRunStateSerializer serializer_class = CourseRunStateSerializer
def setUp(self): def setUp(self):
super(CourseRunStateSerializerTests, self).setUp() super(CourseRunStateSerializerTests, self).setUp()
self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft) self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft)
self.course_run = self.run_state.course_run self.course_run = self.run_state.course_run
self.request = RequestFactory()
self.user = UserFactory() self.user = UserFactory()
self.request.user = self.user self.request.user = self.user
CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course) CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course)
......
...@@ -4,6 +4,7 @@ from urllib.parse import quote ...@@ -4,6 +4,7 @@ from urllib.parse import quote
import ddt import ddt
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail from django.core import mail
from django.db import IntegrityError from django.db import IntegrityError
from django.test import TestCase from django.test import TestCase
...@@ -13,7 +14,6 @@ from mock import mock, patch ...@@ -13,7 +14,6 @@ from mock import mock, patch
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
...@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories ...@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
@ddt.ddt @ddt.ddt
class CourseRoleAssignmentViewTests(SiteMixin, TestCase): class CourseRoleAssignmentViewTests(TestCase):
def setUp(self): def setUp(self):
super(CourseRoleAssignmentViewTests, self).setUp() super(CourseRoleAssignmentViewTests, self).setUp()
...@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(SiteMixin, TestCase): ...@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(SiteMixin, TestCase):
self.assertEqual(len(mail.outbox), 1) self.assertEqual(len(mail.outbox), 1)
class OrganizationGroupUserViewTests(SiteMixin, TestCase): class OrganizationGroupUserViewTests(TestCase):
def setUp(self): def setUp(self):
super(OrganizationGroupUserViewTests, self).setUp() super(OrganizationGroupUserViewTests, self).setUp()
...@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(SiteMixin, TestCase): ...@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(SiteMixin, TestCase):
) )
class UpdateCourseRunViewTests(SiteMixin, TestCase): class UpdateCourseRunViewTests(TestCase):
def setUp(self): def setUp(self):
super(UpdateCourseRunViewTests, self).setUp() super(UpdateCourseRunViewTests, self).setUp()
...@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase): ...@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body) self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
def test_update_preview_url(self): def test_update_preview_url(self):
...@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase): ...@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(SiteMixin, TestCase):
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
class CourseRevisionDetailViewTests(SiteMixin, TestCase): class CourseRevisionDetailViewTests(TestCase):
def setUp(self): def setUp(self):
super(CourseRevisionDetailViewTests, self).setUp() super(CourseRevisionDetailViewTests, self).setUp()
...@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(SiteMixin, TestCase): ...@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(SiteMixin, TestCase):
return self.client.get(path=course_revision_path) return self.client.get(path=course_revision_path)
class ChangeCourseStateViewTests(SiteMixin, TestCase): class ChangeCourseStateViewTests(TestCase):
def setUp(self): def setUp(self):
super(ChangeCourseStateViewTests, self).setUp() super(ChangeCourseStateViewTests, self).setUp()
...@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase): ...@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id}) object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
def test_change_course_state_with_error(self): def test_change_course_state_with_error(self):
...@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase): ...@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(SiteMixin, TestCase):
self._assert_email_sent(course_team_user, subject) self._assert_email_sent(course_team_user, subject)
class ChangeCourseRunStateViewTests(SiteMixin, TestCase): class ChangeCourseRunStateViewTests(TestCase):
def setUp(self): def setUp(self):
super(ChangeCourseRunStateViewTests, self).setUp() super(ChangeCourseRunStateViewTests, self).setUp()
...@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(SiteMixin, TestCase): ...@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(SiteMixin, TestCase):
self.assertIn('has been published', mail.outbox[0].body.strip()) self.assertIn('has been published', mail.outbox[0].body.strip())
class RevertCourseByRevisionTests(SiteMixin, TestCase): class RevertCourseByRevisionTests(TestCase):
def setUp(self): def setUp(self):
super(RevertCourseByRevisionTests, self).setUp() super(RevertCourseByRevisionTests, self).setUp()
...@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(SiteMixin, TestCase): ...@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(SiteMixin, TestCase):
return self.client.put(path=course_revision_path) return self.client.put(path=course_revision_path)
class CoursesAutoCompleteTests(SiteMixin, TestCase): class CoursesAutoCompleteTests(TestCase):
""" Tests for course autocomplete.""" """ Tests for course autocomplete."""
def setUp(self): def setUp(self):
...@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(SiteMixin, TestCase): ...@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(SiteMixin, TestCase):
self.assertEqual(len(data['results']), expected_length) self.assertEqual(len(data['results']), expected_length)
class AcceptAllByRevisionTests(SiteMixin, TestCase): class AcceptAllByRevisionTests(TestCase):
def setUp(self): def setUp(self):
super(AcceptAllByRevisionTests, self).setUp() super(AcceptAllByRevisionTests, self).setUp()
......
import logging import logging
from django.conf import settings from django.conf import settings
from django.contrib.sites.models import Site
from django.core.mail.message import EmailMultiAlternatives from django.core.mail.message import EmailMultiAlternatives
from django.template.loader import get_template from django.template.loader import get_template
from django.urls import reverse from django.urls import reverse
...@@ -15,12 +16,11 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled ...@@ -15,12 +16,11 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def send_email_for_studio_instance_created(course_run, site): def send_email_for_studio_instance_created(course_run):
""" Send an email to course team on studio instance creation. """ Send an email to course team on studio instance creation.
Arguments: Arguments:
course_run (CourseRun): CourseRun object course_run (CourseRun): CourseRun object
site (Site): Current site
""" """
try: try:
course_key = CourseKey.from_string(course_run.lms_course_id) course_key = CourseKey.from_string(course_run.lms_course_id)
...@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run, site): ...@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run, site):
context = { context = {
'course_run': course_run, 'course_run': course_run,
'course_run_page_url': 'https://{host}{path}'.format( 'course_run_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=object_path host=Site.objects.get_current().domain.strip('/'), path=object_path
), ),
'course_name': course_run.course.title, 'course_name': course_run.course.title,
'from_address': from_address, 'from_address': from_address,
...@@ -65,13 +65,12 @@ def send_email_for_studio_instance_created(course_run, site): ...@@ -65,13 +65,12 @@ def send_email_for_studio_instance_created(course_run, site):
raise Exception(error_message) raise Exception(error_message)
def send_email_for_course_creation(course, course_run, site): def send_email_for_course_creation(course, course_run):
""" Send the emails for a course creation. """ Send the emails for a course creation.
Arguments: Arguments:
course (Course): Course object course (Course): Course object
course_run (CourseRun): CourseRun object course_run (CourseRun): CourseRun object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_created.txt' txt_template = 'publisher/email/course_created.txt'
html_template = 'publisher/email/course_created.html' html_template = 'publisher/email/course_created.html'
...@@ -92,7 +91,7 @@ def send_email_for_course_creation(course, course_run, site): ...@@ -92,7 +91,7 @@ def send_email_for_course_creation(course, course_run, site):
'course_team_name': course_team.get_full_name(), 'course_team_name': course_team.get_full_name(),
'project_coordinator_name': project_coordinator.get_full_name(), 'project_coordinator_name': project_coordinator.get_full_name(),
'dashboard_url': 'https://{host}{path}'.format( 'dashboard_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=reverse('publisher:publisher_dashboard') host=Site.objects.get_current().domain.strip('/'), path=reverse('publisher:publisher_dashboard')
), ),
'from_address': from_address, 'from_address': from_address,
'contact_us_email': project_coordinator.email 'contact_us_email': project_coordinator.email
...@@ -114,13 +113,12 @@ def send_email_for_course_creation(course, course_run, site): ...@@ -114,13 +113,12 @@ def send_email_for_course_creation(course, course_run, site):
) )
def send_email_for_send_for_review(course, user, site): def send_email_for_send_for_review(course, user):
""" Send email when course is submitted for review. """ Send email when course is submitted for review.
Arguments: Arguments:
course (Object): Course object course (Object): Course object
user (Object): User object user (Object): User object
site (Site): Current site
""" """
txt_template = 'publisher/email/course/send_for_review.txt' txt_template = 'publisher/email/course/send_for_review.txt'
html_template = 'publisher/email/course/send_for_review.html' html_template = 'publisher/email/course/send_for_review.html'
...@@ -137,22 +135,21 @@ def send_email_for_send_for_review(course, user, site): ...@@ -137,22 +135,21 @@ def send_email_for_send_for_review(course, user, site):
'course_name': course.title, 'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team', 'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
) )
} }
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site) send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course %s', course.id) logger.exception('Failed to send email notifications send for review of course %s', course.id)
def send_email_for_mark_as_reviewed(course, user, site): def send_email_for_mark_as_reviewed(course, user):
""" Send email when course is marked as reviewed. """ Send email when course is marked as reviewed.
Arguments: Arguments:
course (Object): Course object course (Object): Course object
user (Object): User object user (Object): User object
site (Site): Current site
""" """
txt_template = 'publisher/email/course/mark_as_reviewed.txt' txt_template = 'publisher/email/course/mark_as_reviewed.txt'
html_template = 'publisher/email/course/mark_as_reviewed.html' html_template = 'publisher/email/course/mark_as_reviewed.html'
...@@ -169,16 +166,16 @@ def send_email_for_mark_as_reviewed(course, user, site): ...@@ -169,16 +166,16 @@ def send_email_for_mark_as_reviewed(course, user, site):
'course_name': course.title, 'course_name': course.title,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team', 'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'marketing team',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
) )
} }
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site) send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications mark as reviewed of course %s', course.id) logger.exception('Failed to send email notifications mark as reviewed of course %s', course.id)
def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site): def send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user):
""" Send email for course workflow state change. """ Send email for course workflow state change.
Arguments: Arguments:
...@@ -189,7 +186,6 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat ...@@ -189,7 +186,6 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
html_template: (String): Email html template path html_template: (String): Email html template path
context: (Dict): Email template context context: (Dict): Email template context
recipient_user: (Object): User object recipient_user: (Object): User object
site (Site): Current site
""" """
if is_email_notification_enabled(recipient_user): if is_email_notification_enabled(recipient_user):
...@@ -206,7 +202,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat ...@@ -206,7 +202,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
'org_name': course.organizations.all().first().name, 'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_page_url': 'https://{host}{path}'.format( 'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path host=Site.objects.get_current().domain.strip('/'), path=course_page_path
) )
} }
) )
...@@ -223,13 +219,12 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat ...@@ -223,13 +219,12 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
email_msg.send() email_msg.send()
def send_email_for_send_for_review_course_run(course_run, user, site): def send_email_for_send_for_review_course_run(course_run, user):
""" Send email when course-run is submitted for review. """ Send email when course-run is submitted for review.
Arguments: Arguments:
course-run (Object): CourseRun object course-run (Object): CourseRun object
user (Object): User object user (Object): User object
site (Site): Current site
""" """
course = course_run.course course = course_run.course
course_key = CourseKey.from_string(course_run.lms_course_id) course_key = CourseKey.from_string(course_run.lms_course_id)
...@@ -251,23 +246,22 @@ def send_email_for_send_for_review_course_run(course_run, user, site): ...@@ -251,23 +246,22 @@ def send_email_for_send_for_review_course_run(course_run, user, site):
'run_number': course_key.run, 'run_number': course_key.run,
'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'project coordinators', 'sender_team': 'course team' if user_role.role == PublisherUserRole.CourseTeam else 'project coordinators',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
), ),
'studio_url': course_run.studio_url 'studio_url': course_run.studio_url
} }
send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user, site) send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications send for review of course-run %s', course_run.id) logger.exception('Failed to send email notifications send for review of course-run %s', course_run.id)
def send_email_for_mark_as_reviewed_course_run(course_run, user, site): def send_email_for_mark_as_reviewed_course_run(course_run, user):
""" Send email when course-run is marked as reviewed. """ Send email when course-run is marked as reviewed.
Arguments: Arguments:
course_run (Object): CourseRun object course_run (Object): CourseRun object
user (Object): User object user (Object): User object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt' txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html' html_template = 'publisher/email/course_run/mark_as_reviewed.html'
...@@ -290,24 +284,21 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user, site): ...@@ -290,24 +284,21 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user, site):
'run_number': course_key.run, 'run_number': course_key.run,
'sender_team': 'course team', 'sender_team': 'course team',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
) )
} }
send_course_workflow_email( send_course_workflow_email(course, user, subject, txt_template, html_template, context, recipient_user)
course, user, subject, txt_template, html_template, context, recipient_user, site
)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id) logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_to_publisher(course_run, user, site): def send_email_to_publisher(course_run, user):
""" Send email to publisher when course-run is marked as reviewed. """ Send email to publisher when course-run is marked as reviewed.
Arguments: Arguments:
course_run (Object): CourseRun object course_run (Object): CourseRun object
user (Object): User object user (Object): User object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_run/mark_as_reviewed.txt' txt_template = 'publisher/email/course_run/mark_as_reviewed.txt'
html_template = 'publisher/email/course_run/mark_as_reviewed.html' html_template = 'publisher/email/course_run/mark_as_reviewed.html'
...@@ -339,7 +330,7 @@ def send_email_to_publisher(course_run, user, site): ...@@ -339,7 +330,7 @@ def send_email_to_publisher(course_run, user, site):
'sender_team': sender_team, 'sender_team': sender_team,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
) )
} }
...@@ -358,12 +349,11 @@ def send_email_to_publisher(course_run, user, site): ...@@ -358,12 +349,11 @@ def send_email_to_publisher(course_run, user, site):
logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id) logger.exception('Failed to send email notifications for mark as reviewed of course-run %s', course_run.id)
def send_email_preview_accepted(course_run, site): def send_email_preview_accepted(course_run):
""" Send email for preview approved to publisher and project coordinator. """ Send email for preview approved to publisher and project coordinator.
Arguments: Arguments:
course_run (Object): CourseRun object course_run (Object): CourseRun object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_run/preview_accepted.txt' txt_template = 'publisher/email/course_run/preview_accepted.txt'
html_template = 'publisher/email/course_run/preview_accepted.html' html_template = 'publisher/email/course_run/preview_accepted.html'
...@@ -392,10 +382,10 @@ def send_email_preview_accepted(course_run, site): ...@@ -392,10 +382,10 @@ def send_email_preview_accepted(course_run, site):
'org_name': course.organizations.all().first().name, 'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
), ),
'course_page_url': 'https://{host}{path}'.format( 'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path host=Site.objects.get_current().domain.strip('/'), path=course_page_path
) )
} }
template = get_template(txt_template) template = get_template(txt_template)
...@@ -416,12 +406,11 @@ def send_email_preview_accepted(course_run, site): ...@@ -416,12 +406,11 @@ def send_email_preview_accepted(course_run, site):
raise Exception(message) raise Exception(message)
def send_email_preview_page_is_available(course_run, site): def send_email_preview_page_is_available(course_run):
""" Send email for course preview available to course team. """ Send email for course preview available to course team.
Arguments: Arguments:
course_run (Object): CourseRun object course_run (Object): CourseRun object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_run/preview_available.txt' txt_template = 'publisher/email/course_run/preview_available.txt'
html_template = 'publisher/email/course_run/preview_available.html' html_template = 'publisher/email/course_run/preview_available.html'
...@@ -447,10 +436,10 @@ def send_email_preview_page_is_available(course_run, site): ...@@ -447,10 +436,10 @@ def send_email_preview_page_is_available(course_run, site):
'preview_link': course_run.preview_url, 'preview_link': course_run.preview_url,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
), ),
'course_page_url': 'https://{host}{path}'.format( 'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path host=Site.objects.get_current().domain.strip('/'), path=course_page_path
), ),
'platform_name': settings.PLATFORM_NAME 'platform_name': settings.PLATFORM_NAME
} }
...@@ -473,12 +462,11 @@ def send_email_preview_page_is_available(course_run, site): ...@@ -473,12 +462,11 @@ def send_email_preview_page_is_available(course_run, site):
raise Exception(error_message) raise Exception(error_message)
def send_course_run_published_email(course_run, site): def send_course_run_published_email(course_run):
""" Send email when course run is published by publisher. """ Send email when course run is published by publisher.
Arguments: Arguments:
course_run (Object): CourseRun object course_run (Object): CourseRun object
site (Site): Current site
""" """
txt_template = 'publisher/email/course_run/published.txt' txt_template = 'publisher/email/course_run/published.txt'
html_template = 'publisher/email/course_run/published.html' html_template = 'publisher/email/course_run/published.html'
...@@ -504,10 +492,10 @@ def send_course_run_published_email(course_run, site): ...@@ -504,10 +492,10 @@ def send_course_run_published_email(course_run, site):
'recipient_name': course_team_user.get_full_name() or course_team_user.username, 'recipient_name': course_team_user.get_full_name() or course_team_user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'page_url': 'https://{host}{path}'.format( 'page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
), ),
'course_page_url': 'https://{host}{path}'.format( 'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path host=Site.objects.get_current().domain.strip('/'), path=course_page_path
), ),
'platform_name': settings.PLATFORM_NAME, 'platform_name': settings.PLATFORM_NAME,
} }
...@@ -530,13 +518,12 @@ def send_course_run_published_email(course_run, site): ...@@ -530,13 +518,12 @@ def send_course_run_published_email(course_run, site):
raise Exception(error_message) raise Exception(error_message)
def send_change_role_assignment_email(course_role, former_user, site): def send_change_role_assignment_email(course_role, former_user):
""" Send email for role assignment changed. """ Send email for role assignment changed.
Arguments: Arguments:
course_role (Object): CourseUserRole object course_role (Object): CourseUserRole object
former_user (Object): User object former_user (Object): User object
site (Site): Current site
""" """
txt_template = 'publisher/email/role_assignment_changed.txt' txt_template = 'publisher/email/role_assignment_changed.txt'
html_template = 'publisher/email/role_assignment_changed.html' html_template = 'publisher/email/role_assignment_changed.html'
...@@ -562,7 +549,7 @@ def send_change_role_assignment_email(course_role, former_user, site): ...@@ -562,7 +549,7 @@ def send_change_role_assignment_email(course_role, former_user, site):
'current_user_name': course_role.user.get_full_name() or course_role.user.username, 'current_user_name': course_role.user.get_full_name() or course_role.user.username,
'contact_us_email': project_coordinator.email if project_coordinator else '', 'contact_us_email': project_coordinator.email if project_coordinator else '',
'course_url': 'https://{host}{path}'.format( 'course_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=page_path host=Site.objects.get_current().domain.strip('/'), path=page_path
), ),
'platform_name': settings.PLATFORM_NAME, 'platform_name': settings.PLATFORM_NAME,
} }
...@@ -585,12 +572,11 @@ def send_change_role_assignment_email(course_role, former_user, site): ...@@ -585,12 +572,11 @@ def send_change_role_assignment_email(course_role, former_user, site):
raise Exception(error_message) raise Exception(error_message)
def send_email_for_seo_review(course, site): def send_email_for_seo_review(course):
""" Send email when course is submitted for seo review. """ Send email when course is submitted for seo review.
Arguments: Arguments:
course (Object): Course object course (Object): Course object
site (Site): Current site
""" """
txt_template = 'publisher/email/course/seo_review.txt' txt_template = 'publisher/email/course/seo_review.txt'
html_template = 'publisher/email/course/seo_review.html' html_template = 'publisher/email/course/seo_review.html'
...@@ -611,7 +597,7 @@ def send_email_for_seo_review(course, site): ...@@ -611,7 +597,7 @@ def send_email_for_seo_review(course, site):
'org_name': course.organizations.all().first().name, 'org_name': course.organizations.all().first().name,
'contact_us_email': project_coordinator.email, 'contact_us_email': project_coordinator.email,
'course_page_url': 'https://{host}{path}'.format( 'course_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=course_page_path host=Site.objects.get_current().domain.strip('/'), path=course_page_path
) )
} }
...@@ -629,12 +615,11 @@ def send_email_for_seo_review(course, site): ...@@ -629,12 +615,11 @@ def send_email_for_seo_review(course, site):
logger.exception('Failed to send email notifications for legal review requested of course %s', course.id) logger.exception('Failed to send email notifications for legal review requested of course %s', course.id)
def send_email_for_published_course_run_editing(course_run, site): def send_email_for_published_course_run_editing(course_run):
""" Send email when published course-run is edited. """ Send email when published course-run is edited.
Arguments: Arguments:
course-run (Object): Course Run object course-run (Object): Course Run object
site (Site): Current site
""" """
try: try:
course = course_run.course course = course_run.course
...@@ -659,7 +644,7 @@ def send_email_for_published_course_run_editing(course_run, site): ...@@ -659,7 +644,7 @@ def send_email_for_published_course_run_editing(course_run, site):
'recipient_name': publisher_user.get_full_name() or publisher_user.username, 'recipient_name': publisher_user.get_full_name() or publisher_user.username,
'contact_us_email': course.project_coordinator.email, 'contact_us_email': course.project_coordinator.email,
'course_run_page_url': 'https://{host}{path}'.format( 'course_run_page_url': 'https://{host}{path}'.format(
host=site.domain.strip('/'), path=object_path host=Site.objects.get_current().domain.strip('/'), path=object_path
), ),
'course_run_number': course_key.run, 'course_run_number': course_key.run,
} }
......
...@@ -604,7 +604,7 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -604,7 +604,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
# TODO: send email etc. # TODO: send email etc.
pass pass
def change_state(self, state, user, site=None): def change_state(self, state, user):
""" """
Change course workflow state and ownership also send emails if required. Change course workflow state and ownership also send emails if required.
""" """
...@@ -619,12 +619,12 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -619,12 +619,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
elif user_role.role == PublisherUserRole.CourseTeam: elif user_role.role == PublisherUserRole.CourseTeam:
self.change_owner_role(PublisherUserRole.MarketingReviewer) self.change_owner_role(PublisherUserRole.MarketingReviewer)
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_seo_review(self.course, site) emails.send_email_for_seo_review(self.course)
self.review() self.review()
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_send_for_review(self.course, user, site) emails.send_email_for_send_for_review(self.course, user)
elif state == CourseStateChoices.Approved: elif state == CourseStateChoices.Approved:
user_role = self.course.course_user_roles.get(user=user) user_role = self.course.course_user_roles.get(user=user)
...@@ -633,7 +633,7 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -633,7 +633,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
self.approved() self.approved()
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_mark_as_reviewed(self.course, user, site) emails.send_email_for_mark_as_reviewed(self.course, user)
self.save() self.save()
...@@ -731,10 +731,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -731,10 +731,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
pass pass
@transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published) @transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published)
def published(self, site): def published(self):
emails.send_course_run_published_email(self.course_run, site) emails.send_course_run_published_email(self.course_run)
def change_state(self, state, user, site=None): def change_state(self, state, user):
""" """
Change course run workflow state and ownership also send emails if required. Change course run workflow state and ownership also send emails if required.
""" """
...@@ -750,7 +750,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -750,7 +750,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.review() self.review()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_send_for_review_course_run(self.course_run, user, site) emails.send_email_for_send_for_review_course_run(self.course_run, user)
elif state == CourseRunStateChoices.Approved: elif state == CourseRunStateChoices.Approved:
user_role = self.course_run.course.course_user_roles.get(user=user) user_role = self.course_run.course.course_user_roles.get(user=user)
...@@ -759,11 +759,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -759,11 +759,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.approved() self.approved()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user, site) emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user)
emails.send_email_to_publisher(self.course_run, user, site) emails.send_email_to_publisher(self.course_run, user)
elif state == CourseRunStateChoices.Published: elif state == CourseRunStateChoices.Published:
self.published(site) self.published()
self.save() self.save()
......
...@@ -4,7 +4,6 @@ from django.test import TestCase ...@@ -4,7 +4,6 @@ from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from guardian.shortcuts import get_group_perms from guardian.shortcuts import get_group_perms
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory
from course_discovery.apps.publisher.choices import PublisherUserRole from course_discovery.apps.publisher.choices import PublisherUserRole
...@@ -19,7 +18,7 @@ USER_PASSWORD = 'password' ...@@ -19,7 +18,7 @@ USER_PASSWORD = 'password'
# pylint: disable=no-member # pylint: disable=no-member
class AdminTests(SiteMixin, TestCase): class AdminTests(TestCase):
""" Tests Admin page.""" """ Tests Admin page."""
def setUp(self): def setUp(self):
...@@ -82,7 +81,7 @@ class AdminTests(SiteMixin, TestCase): ...@@ -82,7 +81,7 @@ class AdminTests(SiteMixin, TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class OrganizationExtensionAdminTests(SiteMixin, TestCase): class OrganizationExtensionAdminTests(TestCase):
""" Tests for OrganizationExtensionAdmin.""" """ Tests for OrganizationExtensionAdmin."""
def setUp(self): def setUp(self):
...@@ -135,7 +134,7 @@ class OrganizationExtensionAdminTests(SiteMixin, TestCase): ...@@ -135,7 +134,7 @@ class OrganizationExtensionAdminTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class OrganizationUserRoleAdminTests(SiteMixin, TestCase): class OrganizationUserRoleAdminTests(TestCase):
""" Tests for OrganizationUserRoleAdmin.""" """ Tests for OrganizationUserRoleAdmin."""
def setUp(self): def setUp(self):
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
import mock import mock
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail from django.core import mail
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
...@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories ...@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher.tests.factories import UserAttributeFactory from course_discovery.apps.publisher.tests.factories import UserAttributeFactory
class StudioInstanceCreatedEmailTests(SiteMixin, TestCase): class StudioInstanceCreatedEmailTests(TestCase):
""" """
Tests for the studio instance created email functionality. Tests for the studio instance created email functionality.
""" """
...@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase): ...@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
""" Verify that emails failure raise exception.""" """ Verify that emails failure raise exception."""
with self.assertRaises(Exception) as ex: with self.assertRaises(Exception) as ex:
emails.send_email_for_studio_instance_created(self.course_run, self.site) emails.send_email_for_studio_instance_created(self.course_run)
error_message = 'Failed to send email notifications for course_run [{}]'.format(self.course_run.id) error_message = 'Failed to send email notifications for course_run [{}]'.format(self.course_run.id)
self.assertEqual(ex.message, error_message) self.assertEqual(ex.message, error_message)
def test_email_sent_successfully(self): def test_email_sent_successfully(self):
""" Verify that emails sent successfully for studio instance created.""" """ Verify that emails sent successfully for studio instance created."""
emails.send_email_for_studio_instance_created(self.course_run, self.site) emails.send_email_for_studio_instance_created(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id) course_key = CourseKey.from_string(self.course_run.lms_course_id)
self.assert_email_sent( self.assert_email_sent(
reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}), reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}),
...@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase): ...@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body) self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('Enter course run content in Studio.', body) self.assertIn('Enter course run content in Studio.', body)
self.assertIn('Thanks', body) self.assertIn('Thanks', body)
...@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase): ...@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(SiteMixin, TestCase):
) )
class CourseCreatedEmailTests(SiteMixin, TestCase): class CourseCreatedEmailTests(TestCase):
""" Tests for the new course created email functionality. """ """ Tests for the new course created email functionality. """
def setUp(self): def setUp(self):
...@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase): ...@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
""" Verify that emails failure logs error message.""" """ Verify that emails failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site) emails.send_email_for_course_creation(self.course_run.course, self.course_run)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase): ...@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
def test_email_sent_successfully(self): def test_email_sent_successfully(self):
""" Verify that studio instance request email sent successfully.""" """ Verify that studio instance request email sent successfully."""
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site) emails.send_email_for_course_creation(self.course_run.course, self.course_run)
subject = 'Studio URL requested: {title}'.format(title=self.course_run.course.title) subject = 'Studio URL requested: {title}'.format(title=self.course_run.course.title)
self.assert_email_sent(subject) self.assert_email_sent(subject)
...@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(SiteMixin, TestCase): ...@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(SiteMixin, TestCase):
user_attribute = UserAttributes.objects.get(user=self.user) user_attribute = UserAttributes.objects.get(user=self.user)
user_attribute.enable_email_notification = False user_attribute.enable_email_notification = False
user_attribute.save() user_attribute.save()
emails.send_email_for_course_creation(self.course_run.course, self.course_run, self.site) emails.send_email_for_course_creation(self.course_run.course, self.course_run)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
class SendForReviewEmailTests(SiteMixin, TestCase): class SendForReviewEmailTests(TestCase):
""" Tests for the send for review email functionality. """ """ Tests for the send for review email functionality. """
def setUp(self): def setUp(self):
...@@ -168,7 +168,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase): ...@@ -168,7 +168,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message.""" """ Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review(self.course_state.course, self.user, self.site) emails.send_email_for_send_for_review(self.course_state.course, self.user)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -180,7 +180,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase): ...@@ -180,7 +180,7 @@ class SendForReviewEmailTests(SiteMixin, TestCase):
) )
class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase): class CourseMarkAsReviewedEmailTests(TestCase):
""" Tests for the mark as reviewed email functionality. """ """ Tests for the mark as reviewed email functionality. """
def setUp(self): def setUp(self):
...@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message.""" """ Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user, self.site) emails.send_email_for_mark_as_reviewed(self.course_state.course, self.user)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(SiteMixin, TestCase):
) )
class CourseRunSendForReviewEmailTests(SiteMixin, TestCase): class CourseRunSendForReviewEmailTests(TestCase):
""" Tests for the CourseRun send for review email functionality. """ """ Tests for the CourseRun send for review email functionality. """
def setUp(self): def setUp(self):
...@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase): ...@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory( factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
) )
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user, self.site) emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run) subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user_2) self.assert_email_sent(subject, self.user_2)
...@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase): ...@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory( factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
) )
emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2, self.site) emails.send_email_for_send_for_review_course_run(self.course_run_state.course_run, self.user_2)
subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run) subject = 'Review requested: {title} {run_number}'.format(title=self.course, run_number=self.course_key.run)
self.assert_email_sent(subject, self.user) self.assert_email_sent(subject, self.user)
...@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase): ...@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message.""" """ Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_send_for_review_course_run(self.course_run, self.user, self.site) emails.send_email_for_send_for_review_course_run(self.course_run, self.user)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase): ...@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject) self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}) page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('View this course run in Publisher to review the changes or suggest edits.', body) self.assertIn('View this course run in Publisher to review the changes or suggest edits.', body)
class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): class CourseRunMarkAsReviewedEmailTests(TestCase):
""" Tests for the CourseRun mark as reviewed email functionality. """ """ Tests for the CourseRun mark as reviewed email functionality. """
def setUp(self): def setUp(self):
...@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory( factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
) )
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user, self.site) emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
def test_email_sent_by_course_team(self): def test_email_sent_by_course_team(self):
...@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory( factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
) )
emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2, self.site) emails.send_email_for_mark_as_reviewed_course_run(self.course_run_state.course_run, self.user_2)
self.assert_email_sent(self.user) self.assert_email_sent(self.user)
def test_email_mark_as_reviewed_with_error(self): def test_email_mark_as_reviewed_with_error(self):
""" Verify that email failure log error message.""" """ Verify that email failure log error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user, self.site) emails.send_email_for_mark_as_reviewed_course_run(self.course_run, self.user)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
factories.CourseUserRoleFactory( factories.CourseUserRoleFactory(
course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user course=self.course, role=PublisherUserRole.ProjectCoordinator, user=self.user
) )
emails.send_email_to_publisher(self.course_run_state.course_run, self.user, self.site) emails.send_email_to_publisher(self.course_run_state.course_run, self.user)
self.assert_email_sent(self.user_3) self.assert_email_sent(self.user_3)
def test_email_to_publisher_with_error(self): def test_email_to_publisher_with_error(self):
...@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError): with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_to_publisher(self.course_run, self.user_3, self.site) emails.send_email_to_publisher(self.course_run, self.user_3)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase): ...@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject) self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id}) page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('The review for this course run is complete.', body) self.assertIn('The review for this course run is complete.', body)
class CourseRunPreviewEmailTests(SiteMixin, TestCase): class CourseRunPreviewEmailTests(TestCase):
""" """
Tests for the course preview email functionality. Tests for the course preview email functionality.
""" """
...@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
lms_course_id = 'course-v1:edX+DemoX+Demo_Course' lms_course_id = 'course-v1:edX+DemoX+Demo_Course'
self.run_state.course_run.lms_course_id = lms_course_id self.run_state.course_run.lms_course_id = lms_course_id
emails.send_email_preview_accepted(self.run_state.course_run, self.site) emails.send_email_preview_accepted(self.run_state.course_run)
course_key = CourseKey.from_string(lms_course_id) course_key = CourseKey.from_string(lms_course_id)
subject = 'Publication requested: {course_name} {run_number}'.format( subject = 'Publication requested: {course_name} {run_number}'.format(
...@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject) self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.run_state.course_run.id}) page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': self.run_state.course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('You can now publish this About page.', body) self.assertIn('You can now publish this About page.', body)
...@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
with self.assertRaises(Exception) as ex: with self.assertRaises(Exception) as ex:
self.assertEqual(str(ex.exception), message) self.assertEqual(str(ex.exception), message)
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_preview_accepted(self.run_state.course_run, self.site) emails.send_email_preview_accepted(self.run_state.course_run)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
course_run.lms_course_id = 'course-v1:testX+testX1.0+2017T1' course_run.lms_course_id = 'course-v1:testX+testX1.0+2017T1'
course_run.save() course_run.save()
emails.send_email_preview_page_is_available(course_run, self.site) emails.send_email_preview_page_is_available(course_run)
course_key = CourseKey.from_string(course_run.lms_course_id) course_key = CourseKey.from_string(course_run.lms_course_id)
subject = 'Review requested: Preview for {course_name} {run_number}'.format( subject = 'Review requested: Preview for {course_name} {run_number}'.format(
...@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), subject) self.assertEqual(str(mail.outbox[0].subject), subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': course_run.id}) page_path = reverse('publisher:publisher_course_run_detail', kwargs={'pk': course_run.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('A preview is now available for the', body) self.assertIn('A preview is now available for the', body)
...@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
""" Verify that exception raised on email failure.""" """ Verify that exception raised on email failure."""
with self.assertRaises(Exception) as ex: with self.assertRaises(Exception) as ex:
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site) emails.send_email_preview_page_is_available(self.run_state.course_run)
error_message = 'Failed to send email notifications for preview available of course-run {}'.format( error_message = 'Failed to send email notifications for preview available of course-run {}'.format(
self.run_state.course_run.id self.run_state.course_run.id
) )
...@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase): ...@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(SiteMixin, TestCase):
def test_preview_available_email_with_notification_disabled(self): def test_preview_available_email_with_notification_disabled(self):
""" Verify that email not sent if notification disabled by user.""" """ Verify that email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.course_team_admin, enable_email_notification=False) factories.UserAttributeFactory(user=self.course.course_team_admin, enable_email_notification=False)
emails.send_email_preview_page_is_available(self.run_state.course_run, self.site) emails.send_email_preview_page_is_available(self.run_state.course_run)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
def test_preview_accepted_email_with_notification_disabled(self): def test_preview_accepted_email_with_notification_disabled(self):
""" Verify that preview accepted email not sent if notification disabled by user.""" """ Verify that preview accepted email not sent if notification disabled by user."""
factories.UserAttributeFactory(user=self.course.publisher, enable_email_notification=False) factories.UserAttributeFactory(user=self.course.publisher, enable_email_notification=False)
emails.send_email_preview_accepted(self.run_state.course_run, self.site) emails.send_email_preview_accepted(self.run_state.course_run)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
class CourseRunPublishedEmailTests(SiteMixin, TestCase): class CourseRunPublishedEmailTests(TestCase):
""" """
Tests for course run published email functionality. Tests for course run published email functionality.
""" """
...@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase): ...@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase):
""" """
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2' self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save() self.course_run.save()
emails.send_course_run_published_email(self.course_run, self.site) emails.send_course_run_published_email(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id) course_key = CourseKey.from_string(self.course_run.lms_course_id)
subject = 'Publication complete: About page for {course_name} {run_number}'.format( subject = 'Publication complete: About page for {course_name} {run_number}'.format(
...@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase): ...@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(SiteMixin, TestCase):
) )
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError): with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex: with self.assertRaises(Exception) as ex:
emails.send_course_run_published_email(self.course_run, self.site) emails.send_course_run_published_email(self.course_run)
self.assertEqual(str(ex.exception), message) self.assertEqual(str(ex.exception), message)
class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase): class CourseChangeRoleAssignmentEmailTests(TestCase):
""" """
Tests email functionality for course role assignment changed. Tests email functionality for course role assignment changed.
""" """
...@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase): ...@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
""" """
Verify that course role assignment chnage email functionality works fine. Verify that course role assignment chnage email functionality works fine.
""" """
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site) emails.send_change_role_assignment_email(self.marketing_role, self.user)
expected_subject = '{role_name} changed for {course_title}'.format( expected_subject = '{role_name} changed for {course_title}'.format(
role_name=self.marketing_role.get_role_display().lower(), role_name=self.marketing_role.get_role_display().lower(),
course_title=self.course.title course_title=self.course.title
...@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase): ...@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject) self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id}) page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('has changed.', body) self.assertIn('has changed.', body)
...@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase): ...@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(SiteMixin, TestCase):
) )
with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError): with mock.patch('django.core.mail.message.EmailMessage.send', side_effect=TypeError):
with self.assertRaises(Exception) as ex: with self.assertRaises(Exception) as ex:
emails.send_change_role_assignment_email(self.marketing_role, self.user, self.site) emails.send_change_role_assignment_email(self.marketing_role, self.user)
self.assertEqual(str(ex.exception), message) self.assertEqual(str(ex.exception), message)
class SEOReviewEmailTests(SiteMixin, TestCase): class SEOReviewEmailTests(TestCase):
""" Tests for the seo review email functionality. """ """ Tests for the seo review email functionality. """
def setUp(self): def setUp(self):
...@@ -626,7 +626,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase): ...@@ -626,7 +626,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
""" Verify that email failure logs error message.""" """ Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_seo_review(self.course, self.site) emails.send_email_for_seo_review(self.course)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
...@@ -642,7 +642,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase): ...@@ -642,7 +642,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
Verify that seo review email functionality works fine. Verify that seo review email functionality works fine.
""" """
factories.CourseUserRoleFactory(course=self.course, role=PublisherUserRole.ProjectCoordinator) factories.CourseUserRoleFactory(course=self.course, role=PublisherUserRole.ProjectCoordinator)
emails.send_email_for_seo_review(self.course, self.site) emails.send_email_for_seo_review(self.course)
expected_subject = 'Legal review requested: {title}'.format(title=self.course.title) expected_subject = 'Legal review requested: {title}'.format(title=self.course.title)
self.assertEqual(len(mail.outbox), 1) self.assertEqual(len(mail.outbox), 1)
...@@ -652,7 +652,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase): ...@@ -652,7 +652,7 @@ class SEOReviewEmailTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject) self.assertEqual(str(mail.outbox[0].subject), expected_subject)
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id}) page_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=page_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=page_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
self.assertIn('determine OFAC status', body) self.assertIn('determine OFAC status', body)
...@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests): ...@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
) )
self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2' self.course_run.lms_course_id = 'course-v1:testX+test45+2017T2'
self.course_run.save() self.course_run.save()
emails.send_email_for_published_course_run_editing(self.course_run, self.site) emails.send_email_for_published_course_run_editing(self.course_run)
course_key = CourseKey.from_string(self.course_run.lms_course_id) course_key = CourseKey.from_string(self.course_run.lms_course_id)
...@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests): ...@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
""" Verify that email failure logs error message.""" """ Verify that email failure logs error message."""
with LogCapture(emails.logger.name) as l: with LogCapture(emails.logger.name) as l:
emails.send_email_for_published_course_run_editing(self.course_run, self.site) emails.send_email_for_published_course_run_editing(self.course_run)
l.check( l.check(
( (
emails.logger.name, emails.logger.name,
......
...@@ -6,7 +6,6 @@ from django.urls import reverse ...@@ -6,7 +6,6 @@ from django.urls import reverse
from django_fsm import TransitionNotAllowed from django_fsm import TransitionNotAllowed
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory
...@@ -511,7 +510,7 @@ class GroupOrganizationTests(TestCase): ...@@ -511,7 +510,7 @@ class GroupOrganizationTests(TestCase):
@ddt.ddt @ddt.ddt
class CourseStateTests(SiteMixin, TestCase): class CourseStateTests(TestCase):
""" Tests for the publisher `CourseState` model. """ """ Tests for the publisher `CourseState` model. """
@classmethod @classmethod
...@@ -549,7 +548,7 @@ class CourseStateTests(SiteMixin, TestCase): ...@@ -549,7 +548,7 @@ class CourseStateTests(SiteMixin, TestCase):
""" """
self.assertNotEqual(self.course_state.name, state) self.assertNotEqual(self.course_state.name, state)
self.course_state.change_state(state=state, user=self.user, site=self.site) self.course_state.change_state(state=state, user=self.user)
self.assertEqual(self.course_state.name, state) self.assertEqual(self.course_state.name, state)
...@@ -562,7 +561,7 @@ class CourseStateTests(SiteMixin, TestCase): ...@@ -562,7 +561,7 @@ class CourseStateTests(SiteMixin, TestCase):
self.assertEqual(self.course_state.name, CourseStateChoices.Draft) self.assertEqual(self.course_state.name, CourseStateChoices.Draft)
with self.assertRaises(TransitionNotAllowed): with self.assertRaises(TransitionNotAllowed):
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user, site=self.site) self.course_state.change_state(state=CourseStateChoices.Review, user=self.user)
def test_can_send_for_review(self): def test_can_send_for_review(self):
""" """
...@@ -645,7 +644,7 @@ class CourseStateTests(SiteMixin, TestCase): ...@@ -645,7 +644,7 @@ class CourseStateTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class CourseRunStateTests(SiteMixin, TestCase): class CourseRunStateTests(TestCase):
""" Tests for the publisher `CourseRunState` model. """ """ Tests for the publisher `CourseRunState` model. """
@classmethod @classmethod
...@@ -704,7 +703,7 @@ class CourseRunStateTests(SiteMixin, TestCase): ...@@ -704,7 +703,7 @@ class CourseRunStateTests(SiteMixin, TestCase):
Verify that we can change course-run state according to workflow. Verify that we can change course-run state according to workflow.
""" """
self.assertNotEqual(self.course_run_state.name, state) self.assertNotEqual(self.course_run_state.name, state)
self.course_run_state.change_state(state=state, user=self.user, site=self.site) self.course_run_state.change_state(state=state, user=self.user)
self.assertEqual(self.course_run_state.name, state) self.assertEqual(self.course_run_state.name, state)
def test_with_invalid_parent_course_state(self): def test_with_invalid_parent_course_state(self):
......
...@@ -20,13 +20,11 @@ from opaque_keys.edx.keys import CourseKey ...@@ -20,13 +20,11 @@ from opaque_keys.edx.keys import CourseKey
from pytz import timezone from pytz import timezone
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import (CourseFactory, OrganizationFactory, PersonFactory, from course_discovery.apps.course_metadata.tests.factories import CourseFactory, OrganizationFactory, PersonFactory
SubjectFactory)
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
from course_discovery.apps.publisher.choices import CourseRunStateChoices, CourseStateChoices, PublisherUserRole from course_discovery.apps.publisher.choices import CourseRunStateChoices, CourseStateChoices, PublisherUserRole
from course_discovery.apps.publisher.constants import (ADMIN_GROUP_NAME, INTERNAL_USER_GROUP_NAME, from course_discovery.apps.publisher.constants import (ADMIN_GROUP_NAME, INTERNAL_USER_GROUP_NAME,
...@@ -44,7 +42,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact ...@@ -44,7 +42,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt @ddt.ddt
class CreateCourseViewTests(SiteMixin, TestCase): class CreateCourseViewTests(TestCase):
""" Tests for the publisher `CreateCourseView`. """ """ Tests for the publisher `CreateCourseView`. """
def setUp(self): def setUp(self):
...@@ -63,6 +61,7 @@ class CreateCourseViewTests(SiteMixin, TestCase): ...@@ -63,6 +61,7 @@ class CreateCourseViewTests(SiteMixin, TestCase):
self.course = factories.CourseFactory() self.course = factories.CourseFactory()
self.course.organizations.add(self.organization_extension.organization) self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
# creating default organizations roles # creating default organizations roles
...@@ -270,7 +269,7 @@ class CreateCourseViewTests(SiteMixin, TestCase): ...@@ -270,7 +269,7 @@ class CreateCourseViewTests(SiteMixin, TestCase):
) )
class CreateCourseRunViewTests(SiteMixin, TestCase): class CreateCourseRunViewTests(TestCase):
""" Tests for the publisher `UpdateCourseRunView`. """ """ Tests for the publisher `UpdateCourseRunView`. """
def setUp(self): def setUp(self):
...@@ -300,6 +299,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase): ...@@ -300,6 +299,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase):
current_datetime = datetime.now(timezone('US/Central')) current_datetime = datetime.now(timezone('US/Central'))
self.course_run_dict['start'] = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S') self.course_run_dict['start'] = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
self.course_run_dict['end'] = (current_datetime + timedelta(days=3)).strftime('%Y-%m-%d %H:%M:%S') self.course_run_dict['end'] = (current_datetime + timedelta(days=3)).strftime('%Y-%m-%d %H:%M:%S')
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
def _pop_valuse_from_dict(self, data_dict, key_list): def _pop_valuse_from_dict(self, data_dict, key_list):
...@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase): ...@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class CourseRunDetailTests(SiteMixin, TestCase): class CourseRunDetailTests(TestCase):
""" Tests for the course-run detail view. """ """ Tests for the course-run detail view. """
def setUp(self): def setUp(self):
...@@ -763,8 +763,9 @@ class CourseRunDetailTests(SiteMixin, TestCase): ...@@ -763,8 +763,9 @@ class CourseRunDetailTests(SiteMixin, TestCase):
""" """
self.client.logout() self.client.logout()
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
site = Site.objects.get(pk=settings.SITE_ID)
comment = CommentFactory(content_object=self.course_run, user=self.user, site=self.site) comment = CommentFactory(content_object=self.course_run, user=self.user, site=site)
response = self.client.get(self.page_url) response = self.client.get(self.page_url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self._assert_credits_seats(response, self.wrapped_course_run.credit_seat) self._assert_credits_seats(response, self.wrapped_course_run.credit_seat)
...@@ -778,7 +779,7 @@ class CourseRunDetailTests(SiteMixin, TestCase): ...@@ -778,7 +779,7 @@ class CourseRunDetailTests(SiteMixin, TestCase):
# test decline comment appearing on detail page also. # test decline comment appearing on detail page also.
decline_comment = CommentFactory( decline_comment = CommentFactory(
content_object=self.course_run, content_object=self.course_run,
user=self.user, site=self.site, comment_type=CommentTypeChoices.Decline_Preview user=self.user, site=site, comment_type=CommentTypeChoices.Decline_Preview
) )
response = self.client.get(self.page_url) response = self.client.get(self.page_url)
self.assertContains(response, decline_comment.comment) self.assertContains(response, decline_comment.comment)
...@@ -1232,12 +1233,12 @@ class CourseRunDetailTests(SiteMixin, TestCase): ...@@ -1232,12 +1233,12 @@ class CourseRunDetailTests(SiteMixin, TestCase):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
@ddt.ddt @ddt.ddt
class DashboardTests(SiteMixin, TestCase): class DashboardTests(TestCase):
""" Tests for the `Dashboard`. """ """ Tests for the `Dashboard`. """
def setUp(self): def setUp(self):
super(DashboardTests, self).setUp() super(DashboardTests, self).setUp()
Site.objects.exclude(id=self.site.id).delete()
self.group_internal = Group.objects.get(name=INTERNAL_USER_GROUP_NAME) self.group_internal = Group.objects.get(name=INTERNAL_USER_GROUP_NAME)
self.group_project_coordinator = Group.objects.get(name=PROJECT_COORDINATOR_GROUP_NAME) self.group_project_coordinator = Group.objects.get(name=PROJECT_COORDINATOR_GROUP_NAME)
self.group_reviewer = Group.objects.get(name=REVIEWER_GROUP_NAME) self.group_reviewer = Group.objects.get(name=REVIEWER_GROUP_NAME)
...@@ -1276,18 +1277,7 @@ class DashboardTests(SiteMixin, TestCase): ...@@ -1276,18 +1277,7 @@ class DashboardTests(SiteMixin, TestCase):
def _create_course_assign_role(self, state, user, role): def _create_course_assign_role(self, state, user, role):
""" Create course-run-state, course-user-role and return course-run. """ """ Create course-run-state, course-user-role and return course-run. """
course = factories.CourseFactory( course_run_state = factories.CourseRunStateFactory(name=state, owner_role=role)
primary_subject=SubjectFactory(partner=self.partner),
secondary_subject=SubjectFactory(partner=self.partner),
tertiary_subject=SubjectFactory(partner=self.partner)
)
course_run = factories.CourseRunFactory(course=course)
course_run_state = factories.CourseRunStateFactory(
name=state,
owner_role=role,
course_run=course_run
)
factories.CourseUserRoleFactory(course=course_run_state.course_run.course, role=role, user=user) factories.CourseUserRoleFactory(course=course_run_state.course_run.course, role=role, user=user)
return course_run_state.course_run return course_run_state.course_run
...@@ -1494,7 +1484,8 @@ class DashboardTests(SiteMixin, TestCase): ...@@ -1494,7 +1484,8 @@ class DashboardTests(SiteMixin, TestCase):
Verify that site_name is available in context. Verify that site_name is available in context.
""" """
response = self.client.get(self.page_url) response = self.client.get(self.page_url)
self.assertEqual(response.context['site_name'], self.site.name) site = Site.objects.first()
self.assertEqual(response.context['site_name'], site.name)
def test_filters(self): def test_filters(self):
""" """
...@@ -1512,9 +1503,10 @@ class DashboardTests(SiteMixin, TestCase): ...@@ -1512,9 +1503,10 @@ class DashboardTests(SiteMixin, TestCase):
response = self.client.get(self.page_url) response = self.client.get(self.page_url)
site = Site.objects.first()
self._assert_filter_counts(response, 'All', 3) self._assert_filter_counts(response, 'All', 3)
self._assert_filter_counts(response, 'With Course Team', 2) self._assert_filter_counts(response, 'With Course Team', 2)
self._assert_filter_counts(response, 'With {site_name}'.format(site_name=self.site.name), 1) self._assert_filter_counts(response, 'With {site_name}'.format(site_name=site.name), 1)
def _assert_filter_counts(self, response, expected_label, count): def _assert_filter_counts(self, response, expected_label, count):
""" """
...@@ -1525,7 +1517,7 @@ class DashboardTests(SiteMixin, TestCase): ...@@ -1525,7 +1517,7 @@ class DashboardTests(SiteMixin, TestCase):
self.assertContains(response, expected_count, count=1) self.assertContains(response, expected_count, count=1)
class ToggleEmailNotificationTests(SiteMixin, TestCase): class ToggleEmailNotificationTests(TestCase):
""" Tests for `ToggleEmailNotification` view. """ """ Tests for `ToggleEmailNotification` view. """
def setUp(self): def setUp(self):
...@@ -1558,7 +1550,7 @@ class ToggleEmailNotificationTests(SiteMixin, TestCase): ...@@ -1558,7 +1550,7 @@ class ToggleEmailNotificationTests(SiteMixin, TestCase):
self.assertEqual(is_email_notification_enabled(user), is_enabled) self.assertEqual(is_email_notification_enabled(user), is_enabled)
class CourseListViewTests(SiteMixin, TestCase): class CourseListViewTests(TestCase):
""" Tests for `CourseListView` """ """ Tests for `CourseListView` """
def setUp(self): def setUp(self):
...@@ -1632,7 +1624,7 @@ class CourseListViewTests(SiteMixin, TestCase): ...@@ -1632,7 +1624,7 @@ class CourseListViewTests(SiteMixin, TestCase):
self.assertContains(response, 'Edit') self.assertContains(response, 'Edit')
class CourseDetailViewTests(SiteMixin, TestCase): class CourseDetailViewTests(TestCase):
""" Tests for the course detail view. """ """ Tests for the course detail view. """
def setUp(self): def setUp(self):
...@@ -2064,7 +2056,7 @@ class CourseDetailViewTests(SiteMixin, TestCase): ...@@ -2064,7 +2056,7 @@ class CourseDetailViewTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class CourseEditViewTests(SiteMixin, TestCase): class CourseEditViewTests(TestCase):
""" Tests for the course edit view. """ """ Tests for the course edit view. """
def setUp(self): def setUp(self):
...@@ -2482,7 +2474,7 @@ class CourseEditViewTests(SiteMixin, TestCase): ...@@ -2482,7 +2474,7 @@ class CourseEditViewTests(SiteMixin, TestCase):
@ddt.ddt @ddt.ddt
class CourseRunEditViewTests(SiteMixin, TestCase): class CourseRunEditViewTests(TestCase):
""" Tests for the course run edit view. """ """ Tests for the course run edit view. """
def setUp(self): def setUp(self):
...@@ -2500,6 +2492,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase): ...@@ -2500,6 +2492,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
self.seat = factories.SeatFactory(course_run=self.course_run, type=Seat.VERIFIED, price=2) self.seat = factories.SeatFactory(course_run=self.course_run, type=Seat.VERIFIED, price=2)
self.course.organizations.add(self.organization_extension.organization) self.course.organizations.add(self.organization_extension.organization)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
current_datetime = datetime.now(timezone('US/Central')) current_datetime = datetime.now(timezone('US/Central'))
self.start_date_time = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S') self.start_date_time = (current_datetime + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
...@@ -2782,7 +2775,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase): ...@@ -2782,7 +2775,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body) self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
def test_studio_instance_with_course_team(self): def test_studio_instance_with_course_team(self):
...@@ -3011,7 +3004,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase): ...@@ -3011,7 +3004,7 @@ class CourseRunEditViewTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject) self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CourseRevisionViewTests(SiteMixin, TestCase): class CourseRevisionViewTests(TestCase):
""" Tests for CourseReview""" """ Tests for CourseReview"""
def setUp(self): def setUp(self):
...@@ -3063,7 +3056,7 @@ class CourseRevisionViewTests(SiteMixin, TestCase): ...@@ -3063,7 +3056,7 @@ class CourseRevisionViewTests(SiteMixin, TestCase):
return self.client.get(path=revision_path) return self.client.get(path=revision_path)
class CreateRunFromDashboardViewTests(SiteMixin, TestCase): class CreateRunFromDashboardViewTests(TestCase):
""" Tests for the publisher `CreateRunFromDashboardView`. """ """ Tests for the publisher `CreateRunFromDashboardView`. """
def setUp(self): def setUp(self):
...@@ -3163,7 +3156,7 @@ class CreateRunFromDashboardViewTests(SiteMixin, TestCase): ...@@ -3163,7 +3156,7 @@ class CreateRunFromDashboardViewTests(SiteMixin, TestCase):
self.assertEqual(str(mail.outbox[0].subject), expected_subject) self.assertEqual(str(mail.outbox[0].subject), expected_subject)
class CreateAdminImportCourseTest(SiteMixin, TestCase): class CreateAdminImportCourseTest(TestCase):
""" Tests for the publisher `CreateAdminImportCourse`. """ """ Tests for the publisher `CreateAdminImportCourse`. """
def setUp(self): def setUp(self):
......
...@@ -394,16 +394,14 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView): ...@@ -394,16 +394,14 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published: if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published:
# If latest run of this course is published send an email to Publisher and don't change state. # If latest run of this course is published send an email to Publisher and don't change state.
send_email_for_published_course_run_editing(latest_run, self.request.site) send_email_for_published_course_run_editing(latest_run)
else: else:
user_role = self.object.course_user_roles.get(user=user) user_role = self.object.course_user_roles.get(user=user)
# Change course state to draft if marketing not yet reviewed or # Change course state to draft if marketing not yet reviewed or
# if marketing person updating the course. # if marketing person updating the course.
if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer: if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer:
if self.object.course_state.name != CourseStateChoices.Draft: if self.object.course_state.name != CourseStateChoices.Draft:
self.object.course_state.change_state( self.object.course_state.change_state(state=CourseStateChoices.Draft, user=user)
state=CourseStateChoices.Draft, user=user, site=self.request.site
)
# Change ownership if user role not equal to owner role. # Change ownership if user role not equal to owner role.
if self.object.course_state.owner_role != user_role.role: if self.object.course_state.owner_role != user_role.role:
...@@ -601,7 +599,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView): ...@@ -601,7 +599,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
) )
messages.success(request, success_msg) messages.success(request, success_msg)
emails.send_email_for_course_creation(parent_course, course_run, request.site) emails.send_email_for_course_creation(parent_course, course_run)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id})) return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as error: # pylint: disable=broad-except except Exception as error: # pylint: disable=broad-except
# pylint: disable=no-member # pylint: disable=no-member
...@@ -742,10 +740,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix ...@@ -742,10 +740,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state = course_run.course_run_state course_run_state = course_run.course_run_state
if course_run_state.name not in immutable_states: if course_run_state.name not in immutable_states:
course_run_state.change_state(state=CourseStateChoices.Draft, user=user, site=request.site) course_run_state.change_state(state=CourseStateChoices.Draft, user=user)
if course_run.lms_course_id and lms_course_id != course_run.lms_course_id: if course_run.lms_course_id and lms_course_id != course_run.lms_course_id:
emails.send_email_for_studio_instance_created(course_run, site=request.site) emails.send_email_for_studio_instance_created(course_run)
# pylint: disable=no-member # pylint: disable=no-member
messages.success(request, _('Course run updated successfully.')) messages.success(request, _('Course run updated successfully.'))
...@@ -759,7 +757,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix ...@@ -759,7 +757,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state.change_owner_role(user_role) course_run_state.change_owner_role(user_role)
if CourseRunStateChoices.Published == course_run_state.name: if CourseRunStateChoices.Published == course_run_state.name:
send_email_for_published_course_run_editing(course_run, request.site) send_email_for_published_course_run_editing(course_run)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id})) return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
......
...@@ -3,7 +3,6 @@ import json ...@@ -3,7 +3,6 @@ import json
from django.test import TestCase from django.test import TestCase
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE
from course_discovery.apps.publisher.tests.factories import CourseRunFactory from course_discovery.apps.publisher.tests.factories import CourseRunFactory
...@@ -12,7 +11,7 @@ from course_discovery.apps.publisher_comments.models import Comments ...@@ -12,7 +11,7 @@ from course_discovery.apps.publisher_comments.models import Comments
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class PostCommentTests(SiteMixin, TestCase): class PostCommentTests(TestCase):
def generate_data(self, obj): def generate_data(self, obj):
"""Generate data for the form.""" """Generate data for the form."""
...@@ -40,7 +39,7 @@ class PostCommentTests(SiteMixin, TestCase): ...@@ -40,7 +39,7 @@ class PostCommentTests(SiteMixin, TestCase):
self.assertEqual(comment.user_email, generated_data['email']) self.assertEqual(comment.user_email, generated_data['email'])
class UpdateCommentTests(SiteMixin, TestCase): class UpdateCommentTests(TestCase):
def setUp(self): def setUp(self):
super(UpdateCommentTests, self).setUp() super(UpdateCommentTests, self).setUp()
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import factories from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher_comments.forms import CommentsAdminForm from course_discovery.apps.publisher_comments.forms import CommentsAdminForm
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class AdminTests(SiteMixin, TestCase): class AdminTests(TestCase):
""" Tests Admin page and customize form.""" """ Tests Admin page and customize form."""
def setUp(self): def setUp(self):
super(AdminTests, self).setUp() super(AdminTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.course = factories.CourseFactory() self.course = factories.CourseFactory()
self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site) self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site)
......
import ddt import ddt
import mock import mock
from django.conf import settings
from django.contrib.sites.models import Site
from django.core import mail from django.core import mail
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.publisher.choices import PublisherUserRole from course_discovery.apps.publisher.choices import PublisherUserRole
...@@ -19,7 +20,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact ...@@ -19,7 +20,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt @ddt.ddt
class CommentsEmailTests(SiteMixin, TestCase): class CommentsEmailTests(TestCase):
""" Tests for the e-mail functionality for course, course-run and seats. """ """ Tests for the e-mail functionality for course, course-run and seats. """
def setUp(self): def setUp(self):
...@@ -29,6 +30,8 @@ class CommentsEmailTests(SiteMixin, TestCase): ...@@ -29,6 +30,8 @@ class CommentsEmailTests(SiteMixin, TestCase):
self.user_2 = UserFactory() self.user_2 = UserFactory()
self.user_3 = UserFactory() self.user_3 = UserFactory()
self.site = Site.objects.get(pk=settings.SITE_ID)
self.organization_extension = factories.OrganizationExtensionFactory() self.organization_extension = factories.OrganizationExtensionFactory()
self.seat = factories.SeatFactory() self.seat = factories.SeatFactory()
......
...@@ -480,6 +480,8 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20 ...@@ -480,6 +480,8 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
DEFAULT_PARTNER_ID = None DEFAULT_PARTNER_ID = None
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
SITE_ID = 1
COMMENTS_APP = 'course_discovery.apps.publisher_comments' COMMENTS_APP = 'course_discovery.apps.publisher_comments'
TAGGIT_CASE_INSENSITIVE = True TAGGIT_CASE_INSENSITIVE = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment