Commit e251012e by Vedran Karacic Committed by Vedran Karačić

Filter API data by partner

LEARNER-1119
parent 0c869dd7
import logging import logging
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models import QuerySet from django.db.models import QuerySet
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
...@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user ...@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user
from rest_framework.exceptions import NotFound, PermissionDenied from rest_framework.exceptions import NotFound, PermissionDenied
from course_discovery.apps.api.utils import cast2int from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Organization, Program
...@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter): ...@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested # Return data for the default partner, if no partner is requested
if not any(field in filters for field in ('partner', 'partner_exact')): if not any(field in filters for field in ('partner', 'partner_exact')):
filters['partner'] = Partner.objects.get(pk=settings.DEFAULT_PARTNER_ID).short_code filters['partner'] = request.site.partner.short_code
return filters return filters
......
...@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer): ...@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags = TagListSerializerField() tags = TagListSerializerField()
@classmethod @classmethod
def prefetch_queryset(cls): def prefetch_queryset(cls, partner):
return Organization.objects.all().select_related('partner').prefetch_related('tags') return Organization.objects.filter(partner=partner).select_related('partner').prefetch_related('tags')
class Meta(MinimalOrganizationSerializer.Meta): class Meta(MinimalOrganizationSerializer.Meta):
fields = MinimalOrganizationSerializer.Meta.fields + ( fields = MinimalOrganizationSerializer.Meta.fields + (
...@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer): ...@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url = serializers.SerializerMethodField() marketing_url = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None): def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
# Explicitly check for None to avoid returning all Courses when the # Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty. # queryset passed in happens to be empty.
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related( return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)), Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
) )
class Meta(MinimalCourseSerializer.Meta): class Meta(MinimalCourseSerializer.Meta):
...@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer): ...@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs = serializers.SerializerMethodField() programs = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None): def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
""" """
Similar to the CourseSerializer's prefetch_queryset, but prefetches a Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset. filtered CourseRun queryset.
""" """
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related( return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)), Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
) )
def get_course_runs(self, course): def get_course_runs(self, course):
...@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer): ...@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs = serializers.SerializerMethodField() course_runs = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None): def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
""" """
Similar to the CourseSerializer's prefetch_queryset, but prefetches a Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset. filtered CourseRun queryset.
""" """
queryset = queryset if queryset is not None else Course.objects.all() queryset = queryset if queryset is not None else Course.objects.filter(partner=partner)
return queryset.select_related('level_type', 'video', 'partner').prefetch_related( return queryset.select_related('level_type', 'video', 'partner').prefetch_related(
'expected_learning_items', 'expected_learning_items',
'prerequisites', 'prerequisites',
'subjects', 'subjects',
Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)), Prefetch('course_runs', queryset=CourseRunSerializer.prefetch_queryset(queryset=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('sponsoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
) )
def get_course_runs(self, course): def get_course_runs(self, course):
...@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer): ...@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all()) type = serializers.SlugRelatedField(slug_field='name', queryset=ProgramType.objects.all())
@classmethod @classmethod
def prefetch_queryset(cls): def prefetch_queryset(cls, partner):
return Program.objects.all().select_related('type', 'partner').prefetch_related( return Program.objects.filter(partner=partner).select_related('type', 'partner').prefetch_related(
'excluded_course_runs', 'excluded_course_runs',
# `type` is serialized by a third-party serializer. Providing this field name allows us to # `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to # prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
...@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer): ...@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types = serializers.SerializerMethodField() applicable_seat_types = serializers.SerializerMethodField()
@classmethod @classmethod
def prefetch_queryset(cls): def prefetch_queryset(cls, partner):
""" """
Prefetch the related objects that will be serialized with a `Program`. Prefetch the related objects that will be serialized with a `Program`.
...@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer): ...@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching). the querysets that we're prefetching).
""" """
return Program.objects.all().select_related('type', 'video', 'partner').prefetch_related( return Program.objects.filter(partner=partner).select_related('type', 'video', 'partner').prefetch_related(
'excluded_course_runs', 'excluded_course_runs',
'expected_learning_items', 'expected_learning_items',
'faq', 'faq',
...@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer): ...@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types', 'type__applicable_seat_types',
# We need the full Course prefetch here to get CourseRun information that methods on the Program # We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer. # model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset()), Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset()), Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()), Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()), Prefetch('individual_endorsements', queryset=EndorsementSerializer.prefetch_queryset()),
) )
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import RequestFactory
from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory
class SiteMixin(object):
def setUp(self):
super(SiteMixin, self).setUp()
domain = 'testserver.fake'
self.client = self.client_class(SERVER_NAME=domain)
Site.objects.all().delete()
self.site = SiteFactory(id=settings.SITE_ID, domain=domain)
self.partner = PartnerFactory(site=self.site)
self.request = RequestFactory(SERVER_NAME=self.site.domain).get('')
self.request.site = self.site
...@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import ( ...@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer, ProgramSerializer, ProgramTypeSerializer, SeatSerializer, SubjectSerializer, TypeaheadCourseRunSearchSerializer,
TypeaheadProgramSearchSerializer, VideoSerializer TypeaheadProgramSearchSerializer, VideoSerializer
) )
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.models import User from course_discovery.apps.core.models import User
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
...@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase): ...@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
class MinimalCourseSerializerTests(TestCase): class MinimalCourseSerializerTests(SiteMixin, TestCase):
serializer_class = MinimalCourseSerializer serializer_class = MinimalCourseSerializer
def get_expected_data(self, course, request): def get_expected_data(self, course, request):
...@@ -114,8 +115,8 @@ class MinimalCourseSerializerTests(TestCase): ...@@ -114,8 +115,8 @@ class MinimalCourseSerializerTests(TestCase):
def test_data(self): def test_data(self):
request = make_request() request = make_request()
organizations = OrganizationFactory() organizations = OrganizationFactory(partner=self.partner)
course = CourseFactory(authoring_organizations=[organizations]) course = CourseFactory(authoring_organizations=[organizations], partner=self.partner)
CourseRunFactory.create_batch(2, course=course) CourseRunFactory.create_batch(2, course=course)
serializer = self.serializer_class(course, context={'request': request}) serializer = self.serializer_class(course, context={'request': request})
expected = self.get_expected_data(course, request) expected = self.get_expected_data(course, request)
...@@ -178,9 +179,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests): ...@@ -178,9 +179,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.request = make_request() self.request = make_request()
self.course = CourseFactory() self.course = CourseFactory(partner=self.partner)
self.deleted_program = ProgramFactory( self.deleted_program = ProgramFactory(
courses=[self.course], courses=[self.course],
partner=self.partner,
status=ProgramStatus.Deleted status=ProgramStatus.Deleted
) )
......
import ddt import ddt
import pytest
from django.conf import settings
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.test import RequestFactory, TestCase from django.test import RequestFactory, TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.views import api_docs_permission_denied_handler from course_discovery.apps.api.views import api_docs_permission_denied_handler
from course_discovery.apps.core.tests.factories import PartnerFactory, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
@pytest.mark.django_db class TestApiDocs(APITestCase):
class TestApiDocs:
""" """
Regression tests introduced following LEARNER-1590. Regression tests introduced following LEARNER-1590.
""" """
path = reverse('api_docs') path = reverse('api_docs')
def test_api_docs(self, admin_client): def test_api_docs(self):
""" """
Verify that the API docs are available to authenticated clients. Verify that the API docs are available to authenticated clients.
""" """
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID) user = UserFactory(is_staff=True)
self.client.login(username=user.username, password=USER_PASSWORD)
response = admin_client.get(self.path)
response = self.client.get(self.path)
assert response.status_code == 200 assert response.status_code == 200
def test_api_docs_redirect(self, client): def test_api_docs_redirect(self):
""" """
Verify that unauthenticated clients are redirected. Verify that unauthenticated clients are redirected.
""" """
response = client.get(self.path) response = self.client.get(self.path)
assert response.status_code == 302 assert response.status_code == 302
......
...@@ -4,6 +4,7 @@ import json ...@@ -4,6 +4,7 @@ import json
import responses import responses
from django.conf import settings from django.conf import settings
from rest_framework.test import APITestCase as RestAPITestCase
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (
...@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import ( ...@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer, CourseWithProgramsSerializer, FlattenedCourseRunWithCourseSerializer, MinimalProgramSerializer,
OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer OrganizationSerializer, PersonSerializer, ProgramSerializer, ProgramTypeSerializer
) )
from course_discovery.apps.api.tests.mixins import SiteMixin
class SerializationMixin(object): class SerializationMixin(object):
...@@ -88,3 +90,7 @@ class OAuth2Mixin(object): ...@@ -88,3 +90,7 @@ class OAuth2Mixin(object):
content_type='application/json', content_type='application/json',
status=status status=status
) )
class APITestCase(SiteMixin, RestAPITestCase):
pass
...@@ -7,10 +7,9 @@ import ddt ...@@ -7,10 +7,9 @@ import ddt
import pytz import pytz
from lxml import etree from lxml import etree
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import AffiliateWindowSerializer from course_discovery.apps.api.serializers import AffiliateWindowSerializer
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
...@@ -46,8 +45,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP ...@@ -46,8 +45,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def test_affiliate_with_supported_seats(self): def test_affiliate_with_supported_seats(self):
""" Verify that endpoint returns course runs for verified and professional seats only. """ """ Verify that endpoint returns course runs for verified and professional seats only. """
with self.assertNumQueries(8): response = self.client.get(self.affiliate_url)
response = self.client.get(self.affiliate_url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
root = ET.fromstring(response.content) root = ET.fromstring(response.content)
...@@ -130,7 +128,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP ...@@ -130,7 +128,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs # Superusers can view all catalogs
self.client.force_authenticate(superuser) self.client.force_authenticate(superuser)
with self.assertNumQueries(4): with self.assertNumQueries(5):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -140,7 +138,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP ...@@ -140,7 +138,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
catalog.viewers = [self.user] catalog.viewers = [self.user]
with self.assertNumQueries(7): with self.assertNumQueries(8):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
......
...@@ -8,10 +8,9 @@ import pytz ...@@ -8,10 +8,9 @@ import pytz
import responses import responses
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.v1.tests.test_views.mixins import OAuth2Mixin, SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.catalogs.tests.factories import CatalogFactory from course_discovery.apps.catalogs.tests.factories import CatalogFactory
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
...@@ -31,6 +30,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -31,6 +30,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
def setUp(self): def setUp(self):
super(CatalogViewSetTests, self).setUp() super(CatalogViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
self.catalog = CatalogFactory(query='title:abc*') self.catalog = CatalogFactory(query='title:abc*')
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30) enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=30)
...@@ -172,7 +172,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -172,7 +172,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# to be included. # to be included.
filtered_course_run = CourseRunFactory(course=course) filtered_course_run = CourseRunFactory(course=course)
with self.assertNumQueries(16): with self.assertNumQueries(18):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
...@@ -185,8 +185,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -185,8 +185,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run. # Any course appearing in the response must have at least one serialized run.
assert len(response.data['results'][0]['course_runs']) > 0 assert len(response.data['results'][0]['course_runs']) > 0
else: else:
with self.assertNumQueries(2): response = self.client.get(url)
response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
assert response.data['results'] == [] assert response.data['results'] == []
...@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-csv', kwargs={'id': self.catalog.id})
with self.assertNumQueries(17): with self.assertNumQueries(18):
response = self.client.get(url) response = self.client.get(url)
course_run = self.serialize_catalog_flat_course_run(self.course_run) course_run = self.serialize_catalog_flat_course_run(self.course_run)
......
...@@ -4,19 +4,16 @@ import urllib ...@@ -4,19 +4,16 @@ import urllib
import ddt import ddt
import pytz import pytz
from django.conf import settings
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory, APITestCase from rest_framework.test import APIRequestFactory
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.course_metadata.tests.factories import ( from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory, SeatFactory
CourseRunFactory, PartnerFactory, ProgramFactory, SeatFactory
)
@ddt.ddt @ddt.ddt
...@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super(CourseRunViewSetTests, self).setUp() super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
self.course_run = CourseRunFactory(course__partner=self.partner) self.course_run = CourseRunFactory(course__partner=self.partner)
self.course_run_2 = CourseRunFactory(course__partner=self.partner) self.course_run_2 = CourseRunFactory(course__partner=self.partner)
self.refresh_index() self.refresh_index()
...@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key=lambda course_run: course_run['key']) key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted) self.assertListEqual(actual_sorted, expected_sorted)
def test_list_query_invalid_partner(self):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query = 'title:Some random title'
url = '{root}?q={query}&partner={partner}'.format(root=reverse('api:v1:course_run-list'), query=query,
partner='foo')
response = self.client.get(url)
self.assertEqual(response.status_code, 400)
def assert_list_results(self, url, expected, extra_context=None): def assert_list_results(self, url, expected, extra_context=None):
expected = sorted(expected, key=lambda course_run: course_run.key.lower()) expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url) response = self.client.get(url)
...@@ -256,7 +240,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -256,7 +240,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
'course_run_ids': self.course_run.key, 'course_run_ids': self.course_run.key,
}) })
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs) url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.assertEqual( self.assertEqual(
...@@ -268,18 +251,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -268,18 +251,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
} }
) )
def test_contains_single_course_run_invalid_partner(self):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs = urllib.parse.urlencode({
'query': 'id:course*',
'course_run_ids': self.course_run.key,
'partner': 'foo'
})
url = '{}?{}'.format(reverse('api:v1:course_run-contains'), qs)
response = self.client.get(url)
assert response.status_code == 400
def test_contains_multiple_course_runs(self): def test_contains_multiple_course_runs(self):
qs = urllib.parse.urlencode({ qs = urllib.parse.urlencode({
'query': 'id:course*', 'query': 'id:course*',
......
...@@ -4,9 +4,8 @@ import ddt ...@@ -4,9 +4,8 @@ import ddt
import pytz import pytz
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course
...@@ -22,14 +21,15 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -22,14 +21,15 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
def setUp(self): def setUp(self):
super(CourseViewSetTests, self).setUp() super(CourseViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.course = CourseFactory() self.course = CourseFactory(partner=self.partner)
def test_get(self): def test_get(self):
""" Verify the endpoint returns the details for a single course. """ """ Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(18): with self.assertNumQueries(20):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.serialize_course(self.course)) self.assertEqual(response.data, self.serialize_course(self.course))
...@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """ """ Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(11): with self.assertNumQueries(13):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.get('programs'), []) self.assertEqual(response.data.get('programs'), [])
...@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted) ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1' url += '?include_deleted_programs=1'
with self.assertNumQueries(22): with self.assertNumQueries(24):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual( self.assertEqual(
...@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """ """ Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list') url = reverse('api:v1:course-list')
with self.assertNumQueries(24): with self.assertNumQueries(26):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertListEqual( self.assertListEqual(
...@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase): ...@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query = 'title:' + title query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
with self.assertNumQueries(37): with self.assertNumQueries(39):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
def test_list_key_filter(self): def test_list_key_filter(self):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """ """ Verify the endpoint returns a list of courses filtered by the specified keys. """
courses = CourseFactory.create_batch(3) courses = CourseFactory.create_batch(3, partner=self.partner)
courses = sorted(courses, key=lambda course: course.key.lower()) courses = sorted(courses, key=lambda course: course.key.lower())
keys = ','.join([course.key for course in courses]) keys = ','.join([course.key for course in courses])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course-list'), keys=keys)
with self.assertNumQueries(37): with self.assertNumQueries(39):
response = self.client.get(url) response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True)) self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......
import uuid import uuid
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import Organization, OrganizationFactory from course_discovery.apps.course_metadata.tests.factories import Organization, OrganizationFactory
...@@ -14,6 +13,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -14,6 +13,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def setUp(self): def setUp(self):
super(OrganizationViewSetTests, self).setUp() super(OrganizationViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.request.user = self.user
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
def test_authentication(self): def test_authentication(self):
...@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def assert_response_data_valid(self, response, organizations, many=True): def assert_response_data_valid(self, response, organizations, many=True):
""" Asserts the response data (only) contains the expected organizations. """ """ Asserts the response data (only) contains the expected organizations. """
actual = response.data actual = response.data
serializer_data = self.serialize_organization(organizations, many=many)
if many: if many:
actual = actual['results'] actual = actual['results']
self.assertEqual(actual, self.serialize_organization(organizations, many=many)) self.assertCountEqual(actual, serializer_data)
def assert_list_uuid_filter(self, organizations): def assert_list_uuid_filter(self, organizations, expected_query_count):
""" Asserts the list endpoint supports filtering by UUID. """ """ Asserts the list endpoint supports filtering by UUID. """
organizations = sorted(organizations, key=lambda o: o.created)
with self.assertNumQueries(5): with self.assertNumQueries(expected_query_count):
uuids = ','.join([organization.uuid.hex for organization in organizations]) uuids = ','.join([organization.uuid.hex for organization in organizations])
url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids) url = '{root}?uuids={uuids}'.format(root=self.list_path, uuids=uuids)
response = self.client.get(url) response = self.client.get(url)
...@@ -45,9 +45,8 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -45,9 +45,8 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assert_response_data_valid(response, organizations) self.assert_response_data_valid(response, organizations)
def assert_list_tag_filter(self, organizations, tags, expected_query_count=5): def assert_list_tag_filter(self, organizations, tags, expected_query_count=7):
""" Asserts the list endpoint supports filtering by tags. """ """ Asserts the list endpoint supports filtering by tags. """
with self.assertNumQueries(expected_query_count): with self.assertNumQueries(expected_query_count):
tags = ','.join(tags) tags = ','.join(tags)
url = '{root}?tags={tags}'.format(root=self.list_path, tags=tags) url = '{root}?tags={tags}'.format(root=self.list_path, tags=tags)
...@@ -58,10 +57,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -58,10 +57,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list(self): def test_list(self):
""" Verify the endpoint returns a list of all organizations. """ """ Verify the endpoint returns a list of all organizations. """
OrganizationFactory.create_batch(3, partner=self.partner)
OrganizationFactory.create_batch(3) with self.assertNumQueries(7):
with self.assertNumQueries(5):
response = self.client.get(self.list_path) response = self.client.get(self.list_path)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -70,22 +68,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -70,22 +68,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_list_uuid_filter(self): def test_list_uuid_filter(self):
""" Verify the endpoint returns a list of organizations filtered by UUID. """ """ Verify the endpoint returns a list of organizations filtered by UUID. """
organizations = OrganizationFactory.create_batch(3) organizations = OrganizationFactory.create_batch(3, partner=self.partner)
# Test with a single UUID # Test with a single UUID
self.assert_list_uuid_filter([organizations[0]]) self.assert_list_uuid_filter([organizations[0]], 7)
# Test with multiple UUIDs # Test with multiple UUIDs
self.assert_list_uuid_filter(organizations) self.assert_list_uuid_filter(organizations, 7)
def test_list_tag_filter(self): def test_list_tag_filter(self):
""" Verify the endpoint returns a list of organizations filtered by tag. """ """ Verify the endpoint returns a list of organizations filtered by tag. """
tag = 'test-org' tag = 'test-org'
organizations = OrganizationFactory.create_batch(2) organizations = OrganizationFactory.create_batch(2, partner=self.partner)
# If no organizations have been tagged, the endpoint should not return any data # If no organizations have been tagged, the endpoint should not return any data
self.assert_list_tag_filter([], [tag], expected_query_count=3) self.assert_list_tag_filter([], [tag], expected_query_count=5)
# Tagged organizations should be returned # Tagged organizations should be returned
organizations[0].tags.add(tag) organizations[0].tags.add(tag)
...@@ -99,7 +97,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase): ...@@ -99,7 +97,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def test_retrieve(self): def test_retrieve(self):
""" Verify the endpoint returns details for a single organization. """ """ Verify the endpoint returns details for a single organization. """
organization = OrganizationFactory() organization = OrganizationFactory(partner=self.partner)
url = reverse('api:v1:organization-detail', kwargs={'uuid': organization.uuid}) url = reverse('api:v1:organization-detail', kwargs={'uuid': organization.uuid})
response = self.client.get(url) response = self.client.get(url)
......
# pylint: disable=redefined-builtin,no-member # pylint: disable=redefined-builtin,no-member
import ddt import ddt
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db import IntegrityError from django.db import IntegrityError
from mock import mock from mock import mock
...@@ -8,20 +7,20 @@ from rest_framework.reverse import reverse ...@@ -8,20 +7,20 @@ from rest_framework.reverse import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.api.v1.views.people import logger as people_logger from course_discovery.apps.api.v1.views.people import logger as people_logger
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.models import Person from course_discovery.apps.course_metadata.models import Person
from course_discovery.apps.course_metadata.people import MarketingSitePeople from course_discovery.apps.course_metadata.people import MarketingSitePeople
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.course_metadata.tests.factories import (OrganizationFactory, PartnerFactory, PersonFactory, from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory, PositionFactory
PositionFactory)
User = get_user_model() User = get_user_model()
@ddt.ddt @ddt.ddt
class PersonViewSetTests(SerializationMixin, APITestCase): class PersonViewSetTests(SerializationMixin, SiteMixin, APITestCase):
""" Tests for the person resource. """ """ Tests for the person resource. """
people_list_url = reverse('api:v1:person-list') people_list_url = reverse('api:v1:person-list')
...@@ -29,13 +28,9 @@ class PersonViewSetTests(SerializationMixin, APITestCase): ...@@ -29,13 +28,9 @@ class PersonViewSetTests(SerializationMixin, APITestCase):
super(PersonViewSetTests, self).setUp() super(PersonViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
self.person = PersonFactory() self.person = PersonFactory(partner=self.partner)
PositionFactory(person=self.person) self.organization = OrganizationFactory(partner=self.partner)
self.organization = OrganizationFactory() PositionFactory(person=self.person, organization=self.organization)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self.partner = PartnerFactory(id=settings.DEFAULT_PARTNER_ID)
toggle_switch('publish_person_to_marketing_site', True) toggle_switch('publish_person_to_marketing_site', True)
self.expected_node = { self.expected_node = {
'resource': 'node', '' 'resource': 'node', ''
......
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, SerializationMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.models import ProgramType from course_discovery.apps.course_metadata.models import ProgramType
from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory from course_discovery.apps.course_metadata.tests.factories import ProgramTypeFactory
...@@ -28,7 +27,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase): ...@@ -28,7 +27,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """ """ Verify the endpoint returns a list of all program types. """
ProgramTypeFactory.create_batch(4) ProgramTypeFactory.create_batch(4)
expected = ProgramType.objects.all() expected = ProgramType.objects.all()
with self.assertNumQueries(5): with self.assertNumQueries(6):
response = self.client.get(self.list_path) response = self.client.get(self.list_path)
assert response.status_code == 200 assert response.status_code == 200
...@@ -39,7 +38,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase): ...@@ -39,7 +38,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type = ProgramTypeFactory() program_type = ProgramTypeFactory()
url = reverse('api:v1:program_type-detail', kwargs={'slug': program_type.slug}) url = reverse('api:v1:program_type-detail', kwargs={'slug': program_type.slug})
with self.assertNumQueries(4): with self.assertNumQueries(5):
response = self.client.get(url) response = self.client.get(url)
assert response.status_code == 200 assert response.status_code == 200
......
...@@ -3,13 +3,12 @@ import json ...@@ -3,13 +3,12 @@ import json
import urllib.parse import urllib.parse
import ddt import ddt
from django.conf import settings
from django.urls import reverse from django.urls import reverse
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from rest_framework.test import APITestCase
from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer, from course_discovery.apps.api.serializers import (CourseRunSearchSerializer, ProgramSearchSerializer,
TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer) TypeaheadCourseRunSearchSerializer, TypeaheadProgramSearchSerializer)
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.views.search import TypeaheadSearchView from course_discovery.apps.api.v1.views.search import TypeaheadSearchView
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
...@@ -88,14 +87,8 @@ class SynonymTestMixin: ...@@ -88,14 +87,8 @@ class SynonymTestMixin:
self.assertDictEqual(response1, response2) self.assertDictEqual(response1, response2)
class DefaultPartnerMixin:
def setUp(self):
super(DefaultPartnerMixin, self).setUp()
self.partner = PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
@ddt.ddt @ddt.ddt
class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
APITestCase): APITestCase):
""" Tests for CourseRunSearchViewSet. """ """ Tests for CourseRunSearchViewSet. """
faceted_path = reverse('api:v1:search-course_runs-facets') faceted_path = reverse('api:v1:search-course_runs-facets')
...@@ -162,7 +155,9 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -162,7 +155,9 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
return course_run, response_data return course_run, response_data
def build_facet_url(self, params): def build_facet_url(self, params):
return 'http://testserver{path}?{query}'.format(path=self.faceted_path, query=urllib.parse.urlencode(params)) return 'http://testserver.fake{path}?{query}'.format(
path=self.faceted_path, query=urllib.parse.urlencode(params)
)
def test_invalid_query_facet(self): def test_invalid_query_facet(self):
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """ """ Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
...@@ -271,7 +266,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -271,7 +266,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
) )
self.reindex_courses(program) self.reindex_courses(program)
with self.assertNumQueries(4): with self.assertNumQueries(5):
response = self.get_response('software', faceted=False) response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -295,7 +290,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -295,7 +290,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory(courses=[course_run.course], status=program_status) ProgramFactory(courses=[course_run.course], status=program_status)
self.reindex_courses(active_program) self.reindex_courses(active_program)
with self.assertNumQueries(5): with self.assertNumQueries(6):
response = self.get_response('software', faceted=False) response = self.get_response('software', faceted=False)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -313,7 +308,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -313,7 +308,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt @ddt.ddt
class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, ElasticsearchTestMixin, class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase): SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-all-facets') path = reverse('api:v1:search-all-facets')
...@@ -438,7 +433,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login ...@@ -438,7 +433,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
assert expected == actual assert expected == actual
class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin, class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, ElasticsearchTestMixin,
SynonymTestMixin, APITestCase): SynonymTestMixin, APITestCase):
path = reverse('api:v1:search-typeahead') path = reverse('api:v1:search-typeahead')
...@@ -620,23 +615,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin, ...@@ -620,23 +615,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
self.serialize_program(harvard_program)] self.serialize_program(harvard_program)]
} }
self.assertDictEqual(response.data, expected) self.assertDictEqual(response.data, expected)
def test_typeahead_partner_filter(self):
""" Ensure that a partner param limits results to that partner. """
course_runs = []
programs = []
for partner in ['edx', 'other']:
title = 'Belongs to partner ' + partner
partner = PartnerFactory(short_code=partner)
course_runs.append(CourseRunFactory(title=title, course=CourseFactory(partner=partner)))
programs.append(ProgramFactory(
title=title, partner=partner,
status=ProgramStatus.Active
))
response = self.get_response({'q': 'partner'}, 'edx')
self.assertEqual(response.status_code, 200)
edx_course_run = course_runs[0]
edx_program = programs[0]
self.assertDictEqual(response.data, {'course_runs': [self.serialize_course_run(edx_course_run)],
'programs': [self.serialize_program(edx_program)]})
...@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course']) queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course']) queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
...@@ -94,6 +94,7 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -94,6 +94,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
course_runs = CourseRun.objects.active().enrollable().marketable() course_runs = CourseRun.objects.active().enrollable().marketable()
queryset = serializers.CatalogCourseSerializer.prefetch_queryset( queryset = serializers.CatalogCourseSerializer.prefetch_queryset(
self.request.site.partner,
queryset=queryset, queryset=queryset,
course_runs=course_runs course_runs=course_runs
) )
......
...@@ -9,14 +9,13 @@ from rest_framework.response import Response ...@@ -9,14 +9,13 @@ from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import ProxiedPagination from course_discovery.apps.api.pagination import ProxiedPagination
from course_discovery.apps.api.utils import get_query_param from course_discovery.apps.api.utils import get_query_param
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member # pylint: disable=no-member
class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): class CourseRunViewSet(viewsets.ModelViewSet):
""" CourseRun resource. """ """ CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter filter_class = filters.CourseRunFilter
...@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
multiple: false multiple: false
""" """
q = self.request.query_params.get('q') q = self.request.query_params.get('q')
partner = self.get_partner() partner = self.request.site.partner
if q: if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code)) qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
...@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
type: string type: string
paramType: query paramType: query
multiple: false multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden - name: hidden
description: Filter based on wether the course run is hidden from search. description: Filter based on wether the course run is hidden from search.
required: false required: false
...@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
""" """
query = request.GET.get('query') query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids') course_run_ids = request.GET.get('course_run_ids')
partner = self.get_partner() partner = self.request.site.partner
if query and course_run_ids: if query and course_run_ids:
course_run_ids = course_run_ids.split(',') course_run_ids = course_run_ids.split(',')
......
...@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class = filters.CourseFilter filter_class = filters.CourseFilter
lookup_field = 'key' lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all()
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = serializers.CourseWithProgramsSerializer serializer_class = serializers.CourseWithProgramsSerializer
...@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class = ProxiedPagination pagination_class = ProxiedPagination
def get_queryset(self): def get_queryset(self):
partner = self.request.site.partner
q = self.request.query_params.get('q') q = self.request.query_params.get('q')
if q: if q:
queryset = Course.search(q) queryset = Course.search(q)
queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset) queryset = self.get_serializer_class().prefetch_queryset(queryset=queryset, partner=partner)
else: else:
if get_query_param(self.request, 'include_hidden_course_runs'): if get_query_param(self.request, 'include_hidden_course_runs'):
course_runs = CourseRun.objects.all() course_runs = CourseRun.objects.filter(course__partner=partner)
else: else:
course_runs = CourseRun.objects.exclude(hidden=True) course_runs = CourseRun.objects.filter(course__partner=partner).exclude(hidden=True)
if get_query_param(self.request, 'marketable_course_runs_only'): if get_query_param(self.request, 'marketable_course_runs_only'):
course_runs = course_runs.marketable().active() course_runs = course_runs.marketable().active()
...@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset = self.get_serializer_class().prefetch_queryset( queryset = self.get_serializer_class().prefetch_queryset(
queryset=self.queryset, queryset=self.queryset,
course_runs=course_runs course_runs=course_runs,
partner=partner
) )
return queryset.order_by(Lower('key')) return queryset.order_by(Lower('key'))
......
...@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f-]+' lookup_value_regex = '[0-9a-f-]+'
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
queryset = serializers.OrganizationSerializer.prefetch_queryset()
serializer_class = serializers.OrganizationSerializer serializer_class = serializers.OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future # Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination. # versions of this API should only support the system default, PageNumberPagination.
pagination_class = ProxiedPagination pagination_class = ProxiedPagination
def get_queryset(self):
partner = self.request.site.partner
return serializers.OrganizationSerializer.prefetch_queryset(partner=partner)
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
""" Retrieve a list of all organizations. """ """ Retrieve a list of all organizations. """
return super(OrganizationViewSet, self).list(request, *args, **kwargs) return super(OrganizationViewSet, self).list(request, *args, **kwargs)
......
...@@ -7,7 +7,6 @@ from rest_framework.response import Response ...@@ -7,7 +7,6 @@ from rest_framework.response import Response
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
from course_discovery.apps.api.pagination import PageNumberPagination from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.exceptions import MarketingSiteAPIClientException, PersonToMarketingException from course_discovery.apps.course_metadata.exceptions import MarketingSiteAPIClientException, PersonToMarketingException
from course_discovery.apps.course_metadata.people import MarketingSitePeople from course_discovery.apps.course_metadata.people import MarketingSitePeople
...@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) ...@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member # pylint: disable=no-member
class PersonViewSet(PartnerMixin, viewsets.ModelViewSet): class PersonViewSet(viewsets.ModelViewSet):
""" PersonSerializer resource. """ """ PersonSerializer resource. """
lookup_field = 'uuid' lookup_field = 'uuid'
...@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet): ...@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
""" Create a new person. """ """ Create a new person. """
person_data = request.data person_data = request.data
partner = self.get_partner() partner = request.site.partner
person_data['partner'] = partner.id person_data['partner'] = partner.id
serializer = self.get_serializer(data=person_data) serializer = self.get_serializer(data=person_data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
......
...@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet): ...@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def get_queryset(self): def get_queryset(self):
# This method prevents prefetches on the program queryset from "stacking," # This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property. # which happens when the queryset is stored in a class property.
return self.get_serializer_class().prefetch_queryset() partner = self.request.site.partner
return self.get_serializer_class().prefetch_queryset(partner)
def get_serializer_context(self, *args, **kwargs): def get_serializer_context(self, *args, **kwargs):
context = super().get_serializer_context(*args, **kwargs) context = super().get_serializer_context(*args, **kwargs)
...@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet): ...@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if get_query_param(self.request, 'uuids_only'): if get_query_param(self.request, 'uuids_only'):
# DRF serializers don't have good support for simple, flat # DRF serializers don't have good support for simple, flat
# representations like the one we want here. # representations like the one we want here.
queryset = self.filter_queryset(Program.objects.all()) queryset = self.filter_queryset(Program.objects.filter(partner=self.request.site.partner))
uuids = queryset.values_list('uuid', flat=True) uuids = queryset.values_list('uuid', flat=True)
return Response(uuids) return Response(uuids)
......
...@@ -12,7 +12,6 @@ from rest_framework.response import Response ...@@ -12,7 +12,6 @@ from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
...@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet): ...@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(PartnerMixin, APIView): class TypeaheadSearchView(APIView):
""" Typeahead for courses and programs. """ """ Typeahead for courses and programs. """
RESULT_COUNT = 3 RESULT_COUNT = 3
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
...@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView): ...@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView):
type: string type: string
""" """
query = request.query_params.get('q') query = request.query_params.get('q')
partner = self.get_partner() partner = request.site.partner
if not query: if not query:
raise ValidationError("The 'q' querystring parameter is required for searching.") raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner) course_runs, programs = self.get_results(query, partner)
......
...@@ -3,10 +3,11 @@ import json ...@@ -3,10 +3,11 @@ import json
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
class UserAutocompleteTests(TestCase): class UserAutocompleteTests(SiteMixin, TestCase):
""" Tests for user autocomplete lookups.""" """ Tests for user autocomplete lookups."""
def setUp(self): def setUp(self):
......
from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import UserThrottleRate from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import USER_PASSWORD, PartnerFactory, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.throttles import OverridableUserRateThrottle from course_discovery.apps.core.throttles import OverridableUserRateThrottle
class RateLimitingTest(APITestCase): class RateLimitingTest(SiteMixin, APITestCase):
""" """
Testing rate limiting of API calls. Testing rate limiting of API calls.
""" """
...@@ -16,8 +16,6 @@ class RateLimitingTest(APITestCase): ...@@ -16,8 +16,6 @@ class RateLimitingTest(APITestCase):
def setUp(self): def setUp(self):
super(RateLimitingTest, self).setUp() super(RateLimitingTest, self).setUp()
PartnerFactory(pk=settings.DEFAULT_PARTNER_ID)
self.url = reverse('api_docs') self.url = reverse('api_docs')
self.user = UserFactory() self.user = UserFactory()
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
......
...@@ -9,18 +9,20 @@ from django.test.utils import override_settings ...@@ -9,18 +9,20 @@ from django.test.utils import override_settings
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_text from django.utils.encoding import force_text
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.constants import Status from course_discovery.apps.core.constants import Status
User = get_user_model() User = get_user_model()
class HealthTests(TestCase): class HealthTests(SiteMixin, TestCase):
"""Tests of the health endpoint.""" """Tests of the health endpoint."""
def test_all_services_available(self): def test_all_services_available(self):
"""Test that the endpoint reports when all services are healthy.""" """Test that the endpoint reports when all services are healthy."""
self._assert_health(200, Status.OK, Status.OK) self._assert_health(200, Status.OK, Status.OK)
@mock.patch('django.contrib.sites.middleware.get_current_site', mock.Mock(return_value=None))
def test_database_outage(self): def test_database_outage(self):
"""Test that the endpoint reports when the database is unavailable.""" """Test that the endpoint reports when the database is unavailable."""
with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError): with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.cursor', side_effect=DatabaseError):
...@@ -42,7 +44,7 @@ class HealthTests(TestCase): ...@@ -42,7 +44,7 @@ class HealthTests(TestCase):
self.assertJSONEqual(force_text(response.content), expected_data) self.assertJSONEqual(force_text(response.content), expected_data)
class AutoAuthTests(TestCase): class AutoAuthTests(SiteMixin, TestCase):
""" Auto Auth view tests. """ """ Auto Auth view tests. """
AUTO_AUTH_PATH = reverse('auto_auth') AUTO_AUTH_PATH = reverse('auto_auth')
......
...@@ -11,6 +11,7 @@ from selenium.webdriver.support import expected_conditions as EC ...@@ -11,6 +11,7 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import Select from selenium.webdriver.support.ui import Select
from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support.wait import WebDriverWait
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import Partner from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
...@@ -23,7 +24,7 @@ from course_discovery.apps.course_metadata.tests import factories ...@@ -23,7 +24,7 @@ from course_discovery.apps.course_metadata.tests import factories
# pylint: disable=no-member # pylint: disable=no-member
@ddt.ddt @ddt.ddt
class AdminTests(TestCase): class AdminTests(SiteMixin, TestCase):
""" Tests Admin page.""" """ Tests Admin page."""
def setUp(self): def setUp(self):
...@@ -190,7 +191,7 @@ class AdminTests(TestCase): ...@@ -190,7 +191,7 @@ class AdminTests(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class ProgramAdminFunctionalTests(LiveServerTestCase): class ProgramAdminFunctionalTests(SiteMixin, LiveServerTestCase):
""" Functional Tests for Admin page.""" """ Functional Tests for Admin page."""
# Required for access to initial data loaded in migrations (e.g., LanguageTags). # Required for access to initial data loaded in migrations (e.g., LanguageTags).
serialized_rollback = True serialized_rollback = True
...@@ -224,7 +225,6 @@ class ProgramAdminFunctionalTests(LiveServerTestCase): ...@@ -224,7 +225,6 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
# ContentTypeManager uses a cache to speed up ContentType retrieval. This # ContentTypeManager uses a cache to speed up ContentType retrieval. This
# cache persists across tests. This is fine in the context of a regular # cache persists across tests. This is fine in the context of a regular
# TestCase which uses a transaction to reset the database between tests. # TestCase which uses a transaction to reset the database between tests.
...@@ -238,6 +238,9 @@ class ProgramAdminFunctionalTests(LiveServerTestCase): ...@@ -238,6 +238,9 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
# stale ContentType objects from being used. # stale ContentType objects from being used.
ContentType.objects.clear_cache() ContentType.objects.clear_cache()
self.site.domain = self.live_server_url.strip('http://')
self.site.save()
self.course_runs = factories.CourseRunFactory.create_batch(2) self.course_runs = factories.CourseRunFactory.create_batch(2)
self.courses = [course_run.course for course_run in self.course_runs] self.courses = [course_run.course for course_run in self.course_runs]
...@@ -349,7 +352,7 @@ class ProgramAdminFunctionalTests(LiveServerTestCase): ...@@ -349,7 +352,7 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
self.assertEqual(self.program.subtitle, subtitle) self.assertEqual(self.program.subtitle, subtitle)
class ProgramEligibilityFilterTests(TestCase): class ProgramEligibilityFilterTests(SiteMixin, TestCase):
""" Tests for Program Eligibility Filter class. """ """ Tests for Program Eligibility Filter class. """
parameter_name = 'eligible_for_one_click_purchase' parameter_name = 'eligible_for_one_click_purchase'
......
...@@ -5,6 +5,7 @@ import ddt ...@@ -5,6 +5,7 @@ import ddt
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.tests.factories import ( from course_discovery.apps.course_metadata.tests.factories import (
CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory
...@@ -16,7 +17,7 @@ from course_discovery.apps.publisher.tests import factories ...@@ -16,7 +17,7 @@ from course_discovery.apps.publisher.tests import factories
@ddt.ddt @ddt.ddt
class AutocompleteTests(TestCase): class AutocompleteTests(SiteMixin, TestCase):
""" Tests for autocomplete lookups.""" """ Tests for autocomplete lookups."""
def setUp(self): def setUp(self):
super(AutocompleteTests, self).setUp() super(AutocompleteTests, self).setUp()
...@@ -118,7 +119,7 @@ class AutocompleteTests(TestCase): ...@@ -118,7 +119,7 @@ class AutocompleteTests(TestCase):
@ddt.ddt @ddt.ddt
class AutoCompletePersonTests(TestCase): class AutoCompletePersonTests(SiteMixin, TestCase):
""" """
Tests for person autocomplete lookups Tests for person autocomplete lookups
""" """
......
...@@ -2,17 +2,17 @@ import datetime ...@@ -2,17 +2,17 @@ import datetime
import urllib.parse import urllib.parse
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.api.v1.tests.test_views.test_search import ( from course_discovery.apps.api.v1.tests.test_views.test_search import (
DefaultPartnerMixin, ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin ElasticsearchTestMixin, LoginMixin, SerializationMixin, SynonymTestMixin
) )
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, ProgramFactory from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, ProgramFactory
from course_discovery.apps.edx_catalog_extensions.api.serializers import DistinctCountsAggregateFacetSearchSerializer from course_discovery.apps.edx_catalog_extensions.api.serializers import DistinctCountsAggregateFacetSearchSerializer
class DistinctCountsAggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, LoginMixin, class DistinctCountsAggregateSearchViewSetTests(SerializationMixin, LoginMixin,
ElasticsearchTestMixin, SynonymTestMixin, APITestCase): ElasticsearchTestMixin, SynonymTestMixin, APITestCase):
path = reverse('extensions:api:v1:search-all-facets') path = reverse('extensions:api:v1:search-all-facets')
......
...@@ -4,13 +4,14 @@ import ddt ...@@ -4,13 +4,14 @@ import ddt
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
# pylint: disable=no-member # pylint: disable=no-member
@ddt.ddt @ddt.ddt
class AutocompleteTests(TestCase): class AutocompleteTests(SiteMixin, TestCase):
""" Tests for autocomplete lookups.""" """ Tests for autocomplete lookups."""
def setUp(self): def setUp(self):
super(AutocompleteTests, self).setUp() super(AutocompleteTests, self).setUp()
......
...@@ -40,7 +40,8 @@ class CourseUserRoleSerializer(serializers.ModelSerializer): ...@@ -40,7 +40,8 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
former_user = instance.user former_user = instance.user
instance = super(CourseUserRoleSerializer, self).update(instance, validated_data) instance = super(CourseUserRoleSerializer, self).update(instance, validated_data)
if not instance.role == PublisherUserRole.CourseTeam: if not instance.role == PublisherUserRole.CourseTeam:
send_change_role_assignment_email(instance, former_user) request = self.context['request']
send_change_role_assignment_email(instance, former_user, request.site)
return instance return instance
...@@ -104,6 +105,7 @@ class CourseRunSerializer(serializers.ModelSerializer): ...@@ -104,6 +105,7 @@ class CourseRunSerializer(serializers.ModelSerializer):
instance = super(CourseRunSerializer, self).update(instance, validated_data) instance = super(CourseRunSerializer, self).update(instance, validated_data)
preview_url = validated_data.get('preview_url') preview_url = validated_data.get('preview_url')
lms_course_id = validated_data.get('lms_course_id') lms_course_id = validated_data.get('lms_course_id')
request = self.context['request']
if preview_url: if preview_url:
# Change ownership to CourseTeam. # Change ownership to CourseTeam.
...@@ -111,10 +113,10 @@ class CourseRunSerializer(serializers.ModelSerializer): ...@@ -111,10 +113,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
if preview_url: if preview_url:
send_email_preview_page_is_available(instance) send_email_preview_page_is_available(instance, site=request.site)
elif lms_course_id: elif lms_course_id:
send_email_for_studio_instance_created(instance) send_email_for_studio_instance_created(instance, site=request.site)
return instance return instance
...@@ -167,7 +169,7 @@ class CourseStateSerializer(serializers.ModelSerializer): ...@@ -167,7 +169,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
state = validated_data.get('name') state = validated_data.get('name')
request = self.context.get('request') request = self.context.get('request')
try: try:
instance.change_state(state=state, user=request.user) instance.change_state(state=state, user=request.user, site=request.site)
except TransitionNotAllowed: except TransitionNotAllowed:
# pylint: disable=no-member # pylint: disable=no-member
raise serializers.ValidationError( raise serializers.ValidationError(
...@@ -204,7 +206,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer): ...@@ -204,7 +206,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
if state: if state:
try: try:
instance.change_state(state=state, user=request.user) instance.change_state(state=state, user=request.user, site=request.site)
except TransitionNotAllowed: except TransitionNotAllowed:
# pylint: disable=no-member # pylint: disable=no-member
raise serializers.ValidationError( raise serializers.ValidationError(
...@@ -223,6 +225,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer): ...@@ -223,6 +225,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
instance.save() instance.save()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
send_email_preview_accepted(instance.course_run) send_email_preview_accepted(instance.course_run, request.site)
return instance return instance
...@@ -5,6 +5,7 @@ from django.test import RequestFactory, TestCase ...@@ -5,6 +5,7 @@ from django.test import RequestFactory, TestCase
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
...@@ -20,7 +21,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour ...@@ -20,7 +21,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
OrganizationExtensionFactory, SeatFactory) OrganizationExtensionFactory, SeatFactory)
class CourseUserRoleSerializerTests(TestCase): class CourseUserRoleSerializerTests(SiteMixin, TestCase):
serializer_class = CourseUserRoleSerializer serializer_class = CourseUserRoleSerializer
def setUp(self): def setUp(self):
...@@ -28,6 +29,7 @@ class CourseUserRoleSerializerTests(TestCase): ...@@ -28,6 +29,7 @@ class CourseUserRoleSerializerTests(TestCase):
self.request = RequestFactory() self.request = RequestFactory()
self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer) self.course_user_role = CourseUserRoleFactory(role=PublisherUserRole.MarketingReviewer)
self.request.user = self.course_user_role.user self.request.user = self.course_user_role.user
self.request.site = self.site
def get_expected_data(self): def get_expected_data(self):
""" Helper method which will return expected serialize data. """ """ Helper method which will return expected serialize data. """
...@@ -138,7 +140,7 @@ class CourseRunSerializerTests(TestCase): ...@@ -138,7 +140,7 @@ class CourseRunSerializerTests(TestCase):
""" """
self.course_run.preview_url = '' self.course_run.preview_url = ''
self.course_run.save() self.course_run.save()
serializer = self.serializer_class(self.course_run) serializer = self.serializer_class(self.course_run, context={'request': self.request})
serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'}) serializer.update(self.course_run, {'preview_url': 'https://example.com/abc/course'})
self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam) self.assertEqual(self.course_state.owner_role, PublisherUserRole.CourseTeam)
...@@ -246,13 +248,12 @@ class CourseRevisionSerializerTests(TestCase): ...@@ -246,13 +248,12 @@ class CourseRevisionSerializerTests(TestCase):
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
class CourseStateSerializerTests(TestCase): class CourseStateSerializerTests(SiteMixin, TestCase):
serializer_class = CourseStateSerializer serializer_class = CourseStateSerializer
def setUp(self): def setUp(self):
super(CourseStateSerializerTests, self).setUp() super(CourseStateSerializerTests, self).setUp()
self.course_state = CourseStateFactory(name=CourseStateChoices.Draft) self.course_state = CourseStateFactory(name=CourseStateChoices.Draft)
self.request = RequestFactory()
self.user = UserFactory() self.user = UserFactory()
self.request.user = self.user self.request.user = self.user
...@@ -289,14 +290,13 @@ class CourseStateSerializerTests(TestCase): ...@@ -289,14 +290,13 @@ class CourseStateSerializerTests(TestCase):
serializer.update(self.course_state, data) serializer.update(self.course_state, data)
class CourseRunStateSerializerTests(TestCase): class CourseRunStateSerializerTests(SiteMixin, TestCase):
serializer_class = CourseRunStateSerializer serializer_class = CourseRunStateSerializer
def setUp(self): def setUp(self):
super(CourseRunStateSerializerTests, self).setUp() super(CourseRunStateSerializerTests, self).setUp()
self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft) self.run_state = CourseRunStateFactory(name=CourseRunStateChoices.Draft)
self.course_run = self.run_state.course_run self.course_run = self.run_state.course_run
self.request = RequestFactory()
self.user = UserFactory() self.user = UserFactory()
self.request.user = self.user self.request.user = self.user
CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course) CourseStateFactory(name=CourseStateChoices.Approved, course=self.course_run.course)
......
...@@ -4,7 +4,6 @@ from urllib.parse import quote ...@@ -4,7 +4,6 @@ from urllib.parse import quote
import ddt import ddt
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.sites.models import Site
from django.core import mail from django.core import mail
from django.db import IntegrityError from django.db import IntegrityError
from django.test import TestCase from django.test import TestCase
...@@ -14,6 +13,7 @@ from mock import mock, patch ...@@ -14,6 +13,7 @@ from mock import mock, patch
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
...@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories ...@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
@ddt.ddt @ddt.ddt
class CourseRoleAssignmentViewTests(TestCase): class CourseRoleAssignmentViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(CourseRoleAssignmentViewTests, self).setUp() super(CourseRoleAssignmentViewTests, self).setUp()
...@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(TestCase): ...@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(TestCase):
self.assertEqual(len(mail.outbox), 1) self.assertEqual(len(mail.outbox), 1)
class OrganizationGroupUserViewTests(TestCase): class OrganizationGroupUserViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(OrganizationGroupUserViewTests, self).setUp() super(OrganizationGroupUserViewTests, self).setUp()
...@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(TestCase): ...@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(TestCase):
) )
class UpdateCourseRunViewTests(TestCase): class UpdateCourseRunViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(UpdateCourseRunViewTests, self).setUp() super(UpdateCourseRunViewTests, self).setUp()
...@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(TestCase): ...@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
self.assertIn(expected_body, body) self.assertIn(expected_body, body)
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
def test_update_preview_url(self): def test_update_preview_url(self):
...@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(TestCase): ...@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(TestCase):
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
class CourseRevisionDetailViewTests(TestCase): class CourseRevisionDetailViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(CourseRevisionDetailViewTests, self).setUp() super(CourseRevisionDetailViewTests, self).setUp()
...@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(TestCase): ...@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(TestCase):
return self.client.get(path=course_revision_path) return self.client.get(path=course_revision_path)
class ChangeCourseStateViewTests(TestCase): class ChangeCourseStateViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(ChangeCourseStateViewTests, self).setUp() super(ChangeCourseStateViewTests, self).setUp()
...@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(TestCase): ...@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(TestCase):
body = mail.outbox[0].body.strip() body = mail.outbox[0].body.strip()
object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id}) object_path = reverse('publisher:publisher_course_detail', kwargs={'pk': self.course.id})
page_url = 'https://{host}{path}'.format(host=Site.objects.get_current().domain.strip('/'), path=object_path) page_url = 'https://{host}{path}'.format(host=self.site.domain.strip('/'), path=object_path)
self.assertIn(page_url, body) self.assertIn(page_url, body)
def test_change_course_state_with_error(self): def test_change_course_state_with_error(self):
...@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(TestCase): ...@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(TestCase):
self._assert_email_sent(course_team_user, subject) self._assert_email_sent(course_team_user, subject)
class ChangeCourseRunStateViewTests(TestCase): class ChangeCourseRunStateViewTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(ChangeCourseRunStateViewTests, self).setUp() super(ChangeCourseRunStateViewTests, self).setUp()
...@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(TestCase): ...@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(TestCase):
self.assertIn('has been published', mail.outbox[0].body.strip()) self.assertIn('has been published', mail.outbox[0].body.strip())
class RevertCourseByRevisionTests(TestCase): class RevertCourseByRevisionTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(RevertCourseByRevisionTests, self).setUp() super(RevertCourseByRevisionTests, self).setUp()
...@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(TestCase): ...@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(TestCase):
return self.client.put(path=course_revision_path) return self.client.put(path=course_revision_path)
class CoursesAutoCompleteTests(TestCase): class CoursesAutoCompleteTests(SiteMixin, TestCase):
""" Tests for course autocomplete.""" """ Tests for course autocomplete."""
def setUp(self): def setUp(self):
...@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(TestCase): ...@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(TestCase):
self.assertEqual(len(data['results']), expected_length) self.assertEqual(len(data['results']), expected_length)
class AcceptAllByRevisionTests(TestCase): class AcceptAllByRevisionTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(AcceptAllByRevisionTests, self).setUp() super(AcceptAllByRevisionTests, self).setUp()
......
...@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
# TODO: send email etc. # TODO: send email etc.
pass pass
def change_state(self, state, user): def change_state(self, state, user, site=None):
""" """
Change course workflow state and ownership also send emails if required. Change course workflow state and ownership also send emails if required.
""" """
...@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
elif user_role.role == PublisherUserRole.CourseTeam: elif user_role.role == PublisherUserRole.CourseTeam:
self.change_owner_role(PublisherUserRole.MarketingReviewer) self.change_owner_role(PublisherUserRole.MarketingReviewer)
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_seo_review(self.course) emails.send_email_for_seo_review(self.course, site)
self.review() self.review()
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_send_for_review(self.course, user) emails.send_email_for_send_for_review(self.course, user, site)
elif state == CourseStateChoices.Approved: elif state == CourseStateChoices.Approved:
user_role = self.course.course_user_roles.get(user=user) user_role = self.course.course_user_roles.get(user=user)
...@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin): ...@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
self.approved() self.approved()
if is_notifications_enabled: if is_notifications_enabled:
emails.send_email_for_mark_as_reviewed(self.course, user) emails.send_email_for_mark_as_reviewed(self.course, user, site)
self.save() self.save()
...@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
pass pass
@transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published) @transition(field=name, source=CourseRunStateChoices.Approved, target=CourseRunStateChoices.Published)
def published(self): def published(self, site):
emails.send_course_run_published_email(self.course_run) emails.send_course_run_published_email(self.course_run, site)
def change_state(self, state, user): def change_state(self, state, user, site=None):
""" """
Change course run workflow state and ownership also send emails if required. Change course run workflow state and ownership also send emails if required.
""" """
...@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.review() self.review()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_send_for_review_course_run(self.course_run, user) emails.send_email_for_send_for_review_course_run(self.course_run, user, site)
elif state == CourseRunStateChoices.Approved: elif state == CourseRunStateChoices.Approved:
user_role = self.course_run.course.course_user_roles.get(user=user) user_role = self.course_run.course.course_user_roles.get(user=user)
...@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin): ...@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self.approved() self.approved()
if waffle.switch_is_active('enable_publisher_email_notifications'): if waffle.switch_is_active('enable_publisher_email_notifications'):
emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user) emails.send_email_for_mark_as_reviewed_course_run(self.course_run, user, site)
emails.send_email_to_publisher(self.course_run, user) emails.send_email_to_publisher(self.course_run, user, site)
elif state == CourseRunStateChoices.Published: elif state == CourseRunStateChoices.Published:
self.published() self.published(site)
self.save() self.save()
......
...@@ -4,6 +4,7 @@ from django.test import TestCase ...@@ -4,6 +4,7 @@ from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from guardian.shortcuts import get_group_perms from guardian.shortcuts import get_group_perms
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory
from course_discovery.apps.publisher.choices import PublisherUserRole from course_discovery.apps.publisher.choices import PublisherUserRole
...@@ -18,7 +19,7 @@ USER_PASSWORD = 'password' ...@@ -18,7 +19,7 @@ USER_PASSWORD = 'password'
# pylint: disable=no-member # pylint: disable=no-member
class AdminTests(TestCase): class AdminTests(SiteMixin, TestCase):
""" Tests Admin page.""" """ Tests Admin page."""
def setUp(self): def setUp(self):
...@@ -81,7 +82,7 @@ class AdminTests(TestCase): ...@@ -81,7 +82,7 @@ class AdminTests(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class OrganizationExtensionAdminTests(TestCase): class OrganizationExtensionAdminTests(SiteMixin, TestCase):
""" Tests for OrganizationExtensionAdmin.""" """ Tests for OrganizationExtensionAdmin."""
def setUp(self): def setUp(self):
...@@ -134,7 +135,7 @@ class OrganizationExtensionAdminTests(TestCase): ...@@ -134,7 +135,7 @@ class OrganizationExtensionAdminTests(TestCase):
@ddt.ddt @ddt.ddt
class OrganizationUserRoleAdminTests(TestCase): class OrganizationUserRoleAdminTests(SiteMixin, TestCase):
""" Tests for OrganizationUserRoleAdmin.""" """ Tests for OrganizationUserRoleAdmin."""
def setUp(self): def setUp(self):
......
...@@ -6,7 +6,7 @@ from django.urls import reverse ...@@ -6,7 +6,7 @@ from django.urls import reverse
from django_fsm import TransitionNotAllowed from django_fsm import TransitionNotAllowed
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import PartnerFactory, SiteFactory, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory from course_discovery.apps.course_metadata.tests.factories import OrganizationFactory, PersonFactory
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
...@@ -525,6 +525,8 @@ class CourseStateTests(TestCase): ...@@ -525,6 +525,8 @@ class CourseStateTests(TestCase):
def setUp(self): def setUp(self):
super(CourseStateTests, self).setUp() super(CourseStateTests, self).setUp()
self.site = SiteFactory()
self.partner = PartnerFactory(site=self.site)
self.course = self.course_state.course self.course = self.course_state.course
self.course.image = make_image_file('test_banner.jpg') self.course.image = make_image_file('test_banner.jpg')
self.course.save() self.course.save()
...@@ -548,7 +550,7 @@ class CourseStateTests(TestCase): ...@@ -548,7 +550,7 @@ class CourseStateTests(TestCase):
""" """
self.assertNotEqual(self.course_state.name, state) self.assertNotEqual(self.course_state.name, state)
self.course_state.change_state(state=state, user=self.user) self.course_state.change_state(state=state, user=self.user, site=self.site)
self.assertEqual(self.course_state.name, state) self.assertEqual(self.course_state.name, state)
...@@ -561,7 +563,7 @@ class CourseStateTests(TestCase): ...@@ -561,7 +563,7 @@ class CourseStateTests(TestCase):
self.assertEqual(self.course_state.name, CourseStateChoices.Draft) self.assertEqual(self.course_state.name, CourseStateChoices.Draft)
with self.assertRaises(TransitionNotAllowed): with self.assertRaises(TransitionNotAllowed):
self.course_state.change_state(state=CourseStateChoices.Review, user=self.user) self.course_state.change_state(state=CourseStateChoices.Review, user=self.user, site=self.site)
def test_can_send_for_review(self): def test_can_send_for_review(self):
""" """
...@@ -673,6 +675,9 @@ class CourseRunStateTests(TestCase): ...@@ -673,6 +675,9 @@ class CourseRunStateTests(TestCase):
language_tag = LanguageTag(code='te-st', name='Test Language') language_tag = LanguageTag(code='te-st', name='Test Language')
language_tag.save() language_tag.save()
self.site = SiteFactory()
self.partner = PartnerFactory(site=self.site)
self.course_run.transcript_languages.add(language_tag) self.course_run.transcript_languages.add(language_tag)
self.course_run.language = language_tag self.course_run.language = language_tag
self.course_run.is_micromasters = True self.course_run.is_micromasters = True
...@@ -703,7 +708,7 @@ class CourseRunStateTests(TestCase): ...@@ -703,7 +708,7 @@ class CourseRunStateTests(TestCase):
Verify that we can change course-run state according to workflow. Verify that we can change course-run state according to workflow.
""" """
self.assertNotEqual(self.course_run_state.name, state) self.assertNotEqual(self.course_run_state.name, state)
self.course_run_state.change_state(state=state, user=self.user) self.course_run_state.change_state(state=state, user=self.user, site=self.site)
self.assertEqual(self.course_run_state.name, state) self.assertEqual(self.course_run_state.name, state)
def test_with_invalid_parent_course_state(self): def test_with_invalid_parent_course_state(self):
......
...@@ -394,14 +394,16 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView): ...@@ -394,14 +394,16 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published: if latest_run and latest_run.course_run_state.name == CourseRunStateChoices.Published:
# If latest run of this course is published send an email to Publisher and don't change state. # If latest run of this course is published send an email to Publisher and don't change state.
send_email_for_published_course_run_editing(latest_run) send_email_for_published_course_run_editing(latest_run, self.request.site)
else: else:
user_role = self.object.course_user_roles.get(user=user) user_role = self.object.course_user_roles.get(user=user)
# Change course state to draft if marketing not yet reviewed or # Change course state to draft if marketing not yet reviewed or
# if marketing person updating the course. # if marketing person updating the course.
if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer: if not self.object.course_state.marketing_reviewed or user_role.role == PublisherUserRole.MarketingReviewer:
if self.object.course_state.name != CourseStateChoices.Draft: if self.object.course_state.name != CourseStateChoices.Draft:
self.object.course_state.change_state(state=CourseStateChoices.Draft, user=user) self.object.course_state.change_state(
state=CourseStateChoices.Draft, user=user, site=self.request.site
)
# Change ownership if user role not equal to owner role. # Change ownership if user role not equal to owner role.
if self.object.course_state.owner_role != user_role.role: if self.object.course_state.owner_role != user_role.role:
...@@ -599,7 +601,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView): ...@@ -599,7 +601,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
) )
messages.success(request, success_msg) messages.success(request, success_msg)
emails.send_email_for_course_creation(parent_course, course_run) emails.send_email_for_course_creation(parent_course, course_run, request.site)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id})) return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as error: # pylint: disable=broad-except except Exception as error: # pylint: disable=broad-except
# pylint: disable=no-member # pylint: disable=no-member
...@@ -740,10 +742,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix ...@@ -740,10 +742,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state = course_run.course_run_state course_run_state = course_run.course_run_state
if course_run_state.name not in immutable_states: if course_run_state.name not in immutable_states:
course_run_state.change_state(state=CourseStateChoices.Draft, user=user) course_run_state.change_state(state=CourseStateChoices.Draft, user=user, site=request.site)
if course_run.lms_course_id and lms_course_id != course_run.lms_course_id: if course_run.lms_course_id and lms_course_id != course_run.lms_course_id:
emails.send_email_for_studio_instance_created(course_run) emails.send_email_for_studio_instance_created(course_run, site=request.site)
# pylint: disable=no-member # pylint: disable=no-member
messages.success(request, _('Course run updated successfully.')) messages.success(request, _('Course run updated successfully.'))
...@@ -757,7 +759,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix ...@@ -757,7 +759,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state.change_owner_role(user_role) course_run_state.change_owner_role(user_role)
if CourseRunStateChoices.Published == course_run_state.name: if CourseRunStateChoices.Published == course_run_state.name:
send_email_for_published_course_run_editing(course_run) send_email_for_published_course_run_editing(course_run, request.site)
return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id})) return HttpResponseRedirect(reverse(self.success_url, kwargs={'pk': course_run.id}))
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
......
...@@ -3,6 +3,7 @@ import json ...@@ -3,6 +3,7 @@ import json
from django.test import TestCase from django.test import TestCase
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE
from course_discovery.apps.publisher.tests.factories import CourseRunFactory from course_discovery.apps.publisher.tests.factories import CourseRunFactory
...@@ -11,7 +12,7 @@ from course_discovery.apps.publisher_comments.models import Comments ...@@ -11,7 +12,7 @@ from course_discovery.apps.publisher_comments.models import Comments
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class PostCommentTests(TestCase): class PostCommentTests(SiteMixin, TestCase):
def generate_data(self, obj): def generate_data(self, obj):
"""Generate data for the form.""" """Generate data for the form."""
...@@ -39,7 +40,7 @@ class PostCommentTests(TestCase): ...@@ -39,7 +40,7 @@ class PostCommentTests(TestCase):
self.assertEqual(comment.user_email, generated_data['email']) self.assertEqual(comment.user_email, generated_data['email'])
class UpdateCommentTests(TestCase): class UpdateCommentTests(SiteMixin, TestCase):
def setUp(self): def setUp(self):
super(UpdateCommentTests, self).setUp() super(UpdateCommentTests, self).setUp()
......
from django.conf import settings
from django.contrib.sites.models import Site
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.publisher.tests import factories from course_discovery.apps.publisher.tests import factories
from course_discovery.apps.publisher_comments.forms import CommentsAdminForm from course_discovery.apps.publisher_comments.forms import CommentsAdminForm
from course_discovery.apps.publisher_comments.tests.factories import CommentFactory from course_discovery.apps.publisher_comments.tests.factories import CommentFactory
class AdminTests(TestCase): class AdminTests(SiteMixin, TestCase):
""" Tests Admin page and customize form.""" """ Tests Admin page and customize form."""
def setUp(self): def setUp(self):
super(AdminTests, self).setUp() super(AdminTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.site = Site.objects.get(pk=settings.SITE_ID)
self.course = factories.CourseFactory() self.course = factories.CourseFactory()
self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site) self.comment = CommentFactory(content_object=self.course, user=self.user, site=self.site)
......
import ddt import ddt
import mock import mock
from django.conf import settings
from django.contrib.sites.models import Site
from django.core import mail from django.core import mail
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from testfixtures import LogCapture from testfixtures import LogCapture
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.course_metadata.tests import toggle_switch from course_discovery.apps.course_metadata.tests import toggle_switch
from course_discovery.apps.publisher.choices import PublisherUserRole from course_discovery.apps.publisher.choices import PublisherUserRole
...@@ -20,7 +19,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact ...@@ -20,7 +19,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt @ddt.ddt
class CommentsEmailTests(TestCase): class CommentsEmailTests(SiteMixin, TestCase):
""" Tests for the e-mail functionality for course, course-run and seats. """ """ Tests for the e-mail functionality for course, course-run and seats. """
def setUp(self): def setUp(self):
...@@ -30,8 +29,6 @@ class CommentsEmailTests(TestCase): ...@@ -30,8 +29,6 @@ class CommentsEmailTests(TestCase):
self.user_2 = UserFactory() self.user_2 = UserFactory()
self.user_3 = UserFactory() self.user_3 = UserFactory()
self.site = Site.objects.get(pk=settings.SITE_ID)
self.organization_extension = factories.OrganizationExtensionFactory() self.organization_extension = factories.OrganizationExtensionFactory()
self.seat = factories.SeatFactory() self.seat = factories.SeatFactory()
......
...@@ -47,6 +47,7 @@ THIRD_PARTY_APPS = [ ...@@ -47,6 +47,7 @@ THIRD_PARTY_APPS = [
'django_fsm', 'django_fsm',
'storages', 'storages',
'django_comments', 'django_comments',
'django_sites_extensions',
'taggit', 'taggit',
'taggit_autosuggest', 'taggit_autosuggest',
'taggit_serializer', 'taggit_serializer',
...@@ -79,6 +80,7 @@ MIDDLEWARE_CLASSES = ( ...@@ -79,6 +80,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware', 'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'django.contrib.sites.middleware.CurrentSiteMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
'social_django.middleware.SocialAuthExceptionMiddleware', 'social_django.middleware.SocialAuthExceptionMiddleware',
'waffle.middleware.WaffleMiddleware', 'waffle.middleware.WaffleMiddleware',
...@@ -476,7 +478,9 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20 ...@@ -476,7 +478,9 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
DEFAULT_PARTNER_ID = None DEFAULT_PARTNER_ID = None
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id # See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
# edx-django-sites-extensions will fallback to this site if we cannot identify the site from the hostname.
SITE_ID = 1 SITE_ID = 1
COMMENTS_APP = 'course_discovery.apps.publisher_comments' COMMENTS_APP = 'course_discovery.apps.publisher_comments'
TAGGIT_CASE_INSENSITIVE = True TAGGIT_CASE_INSENSITIVE = True
......
...@@ -45,3 +45,7 @@ JWT_AUTH['JWT_SECRET_KEY'] = 'course-discovery-jwt-secret-key' ...@@ -45,3 +45,7 @@ JWT_AUTH['JWT_SECRET_KEY'] = 'course-discovery-jwt-secret-key'
LOGGING['handlers']['local'] = {'class': 'logging.NullHandler'} LOGGING['handlers']['local'] = {'class': 'logging.NullHandler'}
PUBLISHER_FROM_EMAIL = 'test@example.com' PUBLISHER_FROM_EMAIL = 'test@example.com'
# Set to 0 to disable edx-django-sites-extensions to retrieve
# the site from cache and risk working with outdated information.
SITE_CACHE_TTL = 0
...@@ -32,6 +32,7 @@ dry-rest-permissions==0.1.6 ...@@ -32,6 +32,7 @@ dry-rest-permissions==0.1.6
edx-auth-backends==1.1.2 edx-auth-backends==1.1.2
edx-ccx-keys==0.2.0 edx-ccx-keys==0.2.0
edx-django-release-util==0.3.1 edx-django-release-util==0.3.1
edx-django-sites-extensions==2.3.0
edx-drf-extensions==1.2.3 edx-drf-extensions==1.2.3
edx-opaque-keys==0.3.1 edx-opaque-keys==0.3.1
edx-rest-api-client==1.6.0 edx-rest-api-client==1.6.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment