Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
C
course-discovery
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
course-discovery
Commits
e251012e
Commit
e251012e
authored
May 31, 2017
by
Vedran Karacic
Committed by
Vedran Karačić
Jul 20, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Filter API data by partner
LEARNER-1119
parent
0c869dd7
Show whitespace changes
Inline
Side-by-side
Showing
46 changed files
with
436 additions
and
449 deletions
+436
-449
course_discovery/apps/api/filters.py
+1
-3
course_discovery/apps/api/serializers.py
+21
-21
course_discovery/apps/api/tests/mixins.py
+18
-0
course_discovery/apps/api/tests/test_serializers.py
+6
-4
course_discovery/apps/api/tests/test_views.py
+9
-11
course_discovery/apps/api/v1/tests/test_views/mixins.py
+6
-0
course_discovery/apps/api/v1/tests/test_views/test_affiliate_window.py
+3
-5
course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
+4
-5
course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
+3
-32
course_discovery/apps/api/v1/tests/test_views/test_courses.py
+10
-10
course_discovery/apps/api/v1/tests/test_views/test_organizations.py
+16
-18
course_discovery/apps/api/v1/tests/test_views/test_people.py
+6
-11
course_discovery/apps/api/v1/tests/test_views/test_program_types.py
+3
-4
course_discovery/apps/api/v1/tests/test_views/test_programs.py
+46
-42
course_discovery/apps/api/v1/tests/test_views/test_search.py
+9
-34
course_discovery/apps/api/v1/views/__init__.py
+0
-15
course_discovery/apps/api/v1/views/catalogs.py
+1
-0
course_discovery/apps/api/v1/views/course_runs.py
+3
-10
course_discovery/apps/api/v1/views/courses.py
+6
-5
course_discovery/apps/api/v1/views/organizations.py
+4
-1
course_discovery/apps/api/v1/views/people.py
+2
-3
course_discovery/apps/api/v1/views/programs.py
+3
-2
course_discovery/apps/api/v1/views/search.py
+2
-3
course_discovery/apps/core/tests/test_lookups.py
+2
-1
course_discovery/apps/core/tests/test_throttles.py
+3
-5
course_discovery/apps/core/tests/test_views.py
+4
-2
course_discovery/apps/course_metadata/tests/test_admin.py
+7
-4
course_discovery/apps/course_metadata/tests/test_lookups.py
+3
-2
course_discovery/apps/edx_catalog_extensions/api/v1/tests/test_views.py
+3
-3
course_discovery/apps/ietf_language_tags/tests/test_lookups.py
+2
-1
course_discovery/apps/publisher/api/serializers.py
+8
-6
course_discovery/apps/publisher/api/tests/test_serializers.py
+6
-6
course_discovery/apps/publisher/api/tests/test_views.py
+12
-12
course_discovery/apps/publisher/emails.py
+51
-36
course_discovery/apps/publisher/models.py
+11
-11
course_discovery/apps/publisher/tests/test_admin.py
+4
-3
course_discovery/apps/publisher/tests/test_emails.py
+47
-47
course_discovery/apps/publisher/tests/test_model.py
+9
-4
course_discovery/apps/publisher/tests/test_views.py
+58
-50
course_discovery/apps/publisher/views.py
+8
-6
course_discovery/apps/publisher_comments/api/tests/test_views.py
+3
-2
course_discovery/apps/publisher_comments/tests/test_admin.py
+2
-4
course_discovery/apps/publisher_comments/tests/test_emails.py
+2
-5
course_discovery/settings/base.py
+4
-0
course_discovery/settings/test.py
+4
-0
requirements/base.txt
+1
-0
No files found.
course_discovery/apps/api/filters.py
View file @
e251012e
import
logging
import
logging
from
django.conf
import
settings
from
django.contrib.auth
import
get_user_model
from
django.contrib.auth
import
get_user_model
from
django.db.models
import
QuerySet
from
django.db.models
import
QuerySet
from
django.utils.translation
import
ugettext
as
_
from
django.utils.translation
import
ugettext
as
_
...
@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user
...
@@ -13,7 +12,6 @@ from guardian.shortcuts import get_objects_for_user
from
rest_framework.exceptions
import
NotFound
,
PermissionDenied
from
rest_framework.exceptions
import
NotFound
,
PermissionDenied
from
course_discovery.apps.api.utils
import
cast2int
from
course_discovery.apps.api.utils
import
cast2int
from
course_discovery.apps.core.models
import
Partner
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Organization
,
Program
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Organization
,
Program
...
@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
...
@@ -93,7 +91,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested
# Return data for the default partner, if no partner is requested
if
not
any
(
field
in
filters
for
field
in
(
'partner'
,
'partner_exact'
)):
if
not
any
(
field
in
filters
for
field
in
(
'partner'
,
'partner_exact'
)):
filters
[
'partner'
]
=
Partner
.
objects
.
get
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
.
short_code
filters
[
'partner'
]
=
request
.
site
.
partner
.
short_code
return
filters
return
filters
...
...
course_discovery/apps/api/serializers.py
View file @
e251012e
...
@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
...
@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags
=
TagListSerializerField
()
tags
=
TagListSerializerField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
):
def
prefetch_queryset
(
cls
,
partner
):
return
Organization
.
objects
.
all
(
)
.
select_related
(
'partner'
)
.
prefetch_related
(
'tags'
)
return
Organization
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'partner'
)
.
prefetch_related
(
'tags'
)
class
Meta
(
MinimalOrganizationSerializer
.
Meta
):
class
Meta
(
MinimalOrganizationSerializer
.
Meta
):
fields
=
MinimalOrganizationSerializer
.
Meta
.
fields
+
(
fields
=
MinimalOrganizationSerializer
.
Meta
.
fields
+
(
...
@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
...
@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url
=
serializers
.
SerializerMethodField
()
marketing_url
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
def
prefetch_queryset
(
cls
,
partner
,
queryset
=
None
,
course_runs
=
None
):
# Explicitly check for None to avoid returning all Courses when the
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
# queryset passed in happens to be empty.
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
)
)
class
Meta
(
MinimalCourseSerializer
.
Meta
):
class
Meta
(
MinimalCourseSerializer
.
Meta
):
...
@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
...
@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs
=
serializers
.
SerializerMethodField
()
programs
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
def
prefetch_queryset
(
cls
,
partner
,
queryset
=
None
,
course_runs
=
None
):
"""
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
filtered CourseRun queryset.
"""
"""
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
)
)
def
get_course_runs
(
self
,
course
):
def
get_course_runs
(
self
,
course
):
...
@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
...
@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs
=
serializers
.
SerializerMethodField
()
course_runs
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
def
prefetch_queryset
(
cls
,
partner
,
queryset
=
None
,
course_runs
=
None
):
"""
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
filtered CourseRun queryset.
"""
"""
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
)
)
def
get_course_runs
(
self
,
course
):
def
get_course_runs
(
self
,
course
):
...
@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
...
@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type
=
serializers
.
SlugRelatedField
(
slug_field
=
'name'
,
queryset
=
ProgramType
.
objects
.
all
())
type
=
serializers
.
SlugRelatedField
(
slug_field
=
'name'
,
queryset
=
ProgramType
.
objects
.
all
())
@classmethod
@classmethod
def
prefetch_queryset
(
cls
):
def
prefetch_queryset
(
cls
,
partner
):
return
Program
.
objects
.
all
(
)
.
select_related
(
'type'
,
'partner'
)
.
prefetch_related
(
return
Program
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'type'
,
'partner'
)
.
prefetch_related
(
'excluded_course_runs'
,
'excluded_course_runs'
,
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
...
@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types
=
serializers
.
SerializerMethodField
()
applicable_seat_types
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
):
def
prefetch_queryset
(
cls
,
partner
):
"""
"""
Prefetch the related objects that will be serialized with a `Program`.
Prefetch the related objects that will be serialized with a `Program`.
...
@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
the querysets that we're prefetching).
"""
"""
return
Program
.
objects
.
all
(
)
.
select_related
(
'type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
Program
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'excluded_course_runs'
,
'excluded_course_runs'
,
'expected_learning_items'
,
'expected_learning_items'
,
'faq'
,
'faq'
,
...
@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types'
,
'type__applicable_seat_types'
,
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch
(
'courses'
,
queryset
=
CourseSerializer
.
prefetch_queryset
()),
Prefetch
(
'courses'
,
queryset
=
CourseSerializer
.
prefetch_queryset
(
partner
=
partner
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'credit_backing_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'credit_backing_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'corporate_endorsements'
,
queryset
=
CorporateEndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'corporate_endorsements'
,
queryset
=
CorporateEndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'individual_endorsements'
,
queryset
=
EndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'individual_endorsements'
,
queryset
=
EndorsementSerializer
.
prefetch_queryset
()),
)
)
...
...
course_discovery/apps/api/tests/mixins.py
0 → 100644
View file @
e251012e
from
django.conf
import
settings
from
django.contrib.sites.models
import
Site
from
django.test
import
RequestFactory
from
course_discovery.apps.core.tests.factories
import
PartnerFactory
,
SiteFactory
class
SiteMixin
(
object
):
def
setUp
(
self
):
super
(
SiteMixin
,
self
)
.
setUp
()
domain
=
'testserver.fake'
self
.
client
=
self
.
client_class
(
SERVER_NAME
=
domain
)
Site
.
objects
.
all
()
.
delete
()
self
.
site
=
SiteFactory
(
id
=
settings
.
SITE_ID
,
domain
=
domain
)
self
.
partner
=
PartnerFactory
(
site
=
self
.
site
)
self
.
request
=
RequestFactory
(
SERVER_NAME
=
self
.
site
.
domain
)
.
get
(
''
)
self
.
request
.
site
=
self
.
site
course_discovery/apps/api/tests/test_serializers.py
View file @
e251012e
...
@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import (
...
@@ -21,6 +21,7 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer
,
ProgramTypeSerializer
,
SeatSerializer
,
SubjectSerializer
,
TypeaheadCourseRunSearchSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
,
SeatSerializer
,
SubjectSerializer
,
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
,
VideoSerializer
TypeaheadProgramSearchSerializer
,
VideoSerializer
)
)
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
...
@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
...
@@ -96,7 +97,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self
.
assertEqual
(
User
.
objects
.
filter
(
username
=
username
)
.
count
(),
0
)
# pylint: disable=no-member
self
.
assertEqual
(
User
.
objects
.
filter
(
username
=
username
)
.
count
(),
0
)
# pylint: disable=no-member
class
MinimalCourseSerializerTests
(
TestCase
):
class
MinimalCourseSerializerTests
(
SiteMixin
,
TestCase
):
serializer_class
=
MinimalCourseSerializer
serializer_class
=
MinimalCourseSerializer
def
get_expected_data
(
self
,
course
,
request
):
def
get_expected_data
(
self
,
course
,
request
):
...
@@ -114,8 +115,8 @@ class MinimalCourseSerializerTests(TestCase):
...
@@ -114,8 +115,8 @@ class MinimalCourseSerializerTests(TestCase):
def
test_data
(
self
):
def
test_data
(
self
):
request
=
make_request
()
request
=
make_request
()
organizations
=
OrganizationFactory
()
organizations
=
OrganizationFactory
(
partner
=
self
.
partner
)
course
=
CourseFactory
(
authoring_organizations
=
[
organizations
])
course
=
CourseFactory
(
authoring_organizations
=
[
organizations
]
,
partner
=
self
.
partner
)
CourseRunFactory
.
create_batch
(
2
,
course
=
course
)
CourseRunFactory
.
create_batch
(
2
,
course
=
course
)
serializer
=
self
.
serializer_class
(
course
,
context
=
{
'request'
:
request
})
serializer
=
self
.
serializer_class
(
course
,
context
=
{
'request'
:
request
})
expected
=
self
.
get_expected_data
(
course
,
request
)
expected
=
self
.
get_expected_data
(
course
,
request
)
...
@@ -178,9 +179,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
...
@@ -178,9 +179,10 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def
setUp
(
self
):
def
setUp
(
self
):
super
()
.
setUp
()
super
()
.
setUp
()
self
.
request
=
make_request
()
self
.
request
=
make_request
()
self
.
course
=
CourseFactory
()
self
.
course
=
CourseFactory
(
partner
=
self
.
partner
)
self
.
deleted_program
=
ProgramFactory
(
self
.
deleted_program
=
ProgramFactory
(
courses
=
[
self
.
course
],
courses
=
[
self
.
course
],
partner
=
self
.
partner
,
status
=
ProgramStatus
.
Deleted
status
=
ProgramStatus
.
Deleted
)
)
...
...
course_discovery/apps/api/tests/test_views.py
View file @
e251012e
import
ddt
import
ddt
import
pytest
from
django.conf
import
settings
from
django.contrib.auth.models
import
AnonymousUser
from
django.contrib.auth.models
import
AnonymousUser
from
django.core.exceptions
import
PermissionDenied
from
django.core.exceptions
import
PermissionDenied
from
django.test
import
RequestFactory
,
TestCase
from
django.test
import
RequestFactory
,
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
from
course_discovery.apps.api.views
import
api_docs_permission_denied_handler
from
course_discovery.apps.api.views
import
api_docs_permission_denied_handler
from
course_discovery.apps.core.tests.factories
import
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
@pytest.mark.django_db
class
TestApiDocs
(
APITestCase
):
class
TestApiDocs
:
"""
"""
Regression tests introduced following LEARNER-1590.
Regression tests introduced following LEARNER-1590.
"""
"""
path
=
reverse
(
'api_docs'
)
path
=
reverse
(
'api_docs'
)
def
test_api_docs
(
self
,
admin_client
):
def
test_api_docs
(
self
):
"""
"""
Verify that the API docs are available to authenticated clients.
Verify that the API docs are available to authenticated clients.
"""
"""
PartnerFactory
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
user
=
UserFactory
(
is_staff
=
True
)
self
.
client
.
login
(
username
=
user
.
username
,
password
=
USER_PASSWORD
)
response
=
admin_client
.
get
(
self
.
path
)
response
=
self
.
client
.
get
(
self
.
path
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
def
test_api_docs_redirect
(
self
,
client
):
def
test_api_docs_redirect
(
self
):
"""
"""
Verify that unauthenticated clients are redirected.
Verify that unauthenticated clients are redirected.
"""
"""
response
=
client
.
get
(
self
.
path
)
response
=
self
.
client
.
get
(
self
.
path
)
assert
response
.
status_code
==
302
assert
response
.
status_code
==
302
...
...
course_discovery/apps/api/v1/tests/test_views/mixins.py
View file @
e251012e
...
@@ -4,6 +4,7 @@ import json
...
@@ -4,6 +4,7 @@ import json
import
responses
import
responses
from
django.conf
import
settings
from
django.conf
import
settings
from
rest_framework.test
import
APITestCase
as
RestAPITestCase
from
rest_framework.test
import
APIRequestFactory
from
rest_framework.test
import
APIRequestFactory
from
course_discovery.apps.api.serializers
import
(
from
course_discovery.apps.api.serializers
import
(
...
@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import (
...
@@ -11,6 +12,7 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer
,
FlattenedCourseRunWithCourseSerializer
,
MinimalProgramSerializer
,
CourseWithProgramsSerializer
,
FlattenedCourseRunWithCourseSerializer
,
MinimalProgramSerializer
,
OrganizationSerializer
,
PersonSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
OrganizationSerializer
,
PersonSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
)
)
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
class
SerializationMixin
(
object
):
class
SerializationMixin
(
object
):
...
@@ -88,3 +90,7 @@ class OAuth2Mixin(object):
...
@@ -88,3 +90,7 @@ class OAuth2Mixin(object):
content_type
=
'application/json'
,
content_type
=
'application/json'
,
status
=
status
status
=
status
)
)
class
APITestCase
(
SiteMixin
,
RestAPITestCase
):
pass
course_discovery/apps/api/v1/tests/test_views/test_affiliate_window.py
View file @
e251012e
...
@@ -7,10 +7,9 @@ import ddt
...
@@ -7,10 +7,9 @@ import ddt
import
pytz
import
pytz
from
lxml
import
etree
from
lxml
import
etree
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.serializers
import
AffiliateWindowSerializer
from
course_discovery.apps.api.serializers
import
AffiliateWindowSerializer
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
...
@@ -46,7 +45,6 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
...
@@ -46,7 +45,6 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def
test_affiliate_with_supported_seats
(
self
):
def
test_affiliate_with_supported_seats
(
self
):
""" Verify that endpoint returns course runs for verified and professional seats only. """
""" Verify that endpoint returns course runs for verified and professional seats only. """
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
get
(
self
.
affiliate_url
)
response
=
self
.
client
.
get
(
self
.
affiliate_url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -130,7 +128,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
...
@@ -130,7 +128,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs
# Superusers can view all catalogs
self
.
client
.
force_authenticate
(
superuser
)
self
.
client
.
force_authenticate
(
superuser
)
with
self
.
assertNumQueries
(
4
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -140,7 +138,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
...
@@ -140,7 +138,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
self
.
assertEqual
(
response
.
status_code
,
403
)
self
.
assertEqual
(
response
.
status_code
,
403
)
catalog
.
viewers
=
[
self
.
user
]
catalog
.
viewers
=
[
self
.
user
]
with
self
.
assertNumQueries
(
7
):
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
View file @
e251012e
...
@@ -8,10 +8,9 @@ import pytz
...
@@ -8,10 +8,9 @@ import pytz
import
responses
import
responses
from
django.contrib.auth
import
get_user_model
from
django.contrib.auth
import
get_user_model
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.tests.jwt_utils
import
generate_jwt_header_for_user
from
course_discovery.apps.api.tests.jwt_utils
import
generate_jwt_header_for_user
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
OAuth2Mixin
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
OAuth2Mixin
,
SerializationMixin
from
course_discovery.apps.catalogs.models
import
Catalog
from
course_discovery.apps.catalogs.models
import
Catalog
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
...
@@ -31,6 +30,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -31,6 +30,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CatalogViewSetTests
,
self
)
.
setUp
()
super
(
CatalogViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
request
.
user
=
self
.
user
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
catalog
=
CatalogFactory
(
query
=
'title:abc*'
)
self
.
catalog
=
CatalogFactory
(
query
=
'title:abc*'
)
enrollment_end
=
datetime
.
datetime
.
now
(
pytz
.
UTC
)
+
datetime
.
timedelta
(
days
=
30
)
enrollment_end
=
datetime
.
datetime
.
now
(
pytz
.
UTC
)
+
datetime
.
timedelta
(
days
=
30
)
...
@@ -172,7 +172,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -172,7 +172,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# to be included.
# to be included.
filtered_course_run
=
CourseRunFactory
(
course
=
course
)
filtered_course_run
=
CourseRunFactory
(
course
=
course
)
with
self
.
assertNumQueries
(
1
6
):
with
self
.
assertNumQueries
(
1
8
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
@@ -185,7 +185,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -185,7 +185,6 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run.
# Any course appearing in the response must have at least one serialized run.
assert
len
(
response
.
data
[
'results'
][
0
][
'course_runs'
])
>
0
assert
len
(
response
.
data
[
'results'
][
0
][
'course_runs'
])
>
0
else
:
else
:
with
self
.
assertNumQueries
(
2
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -218,7 +217,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url
=
reverse
(
'api:v1:catalog-csv'
,
kwargs
=
{
'id'
:
self
.
catalog
.
id
})
url
=
reverse
(
'api:v1:catalog-csv'
,
kwargs
=
{
'id'
:
self
.
catalog
.
id
})
with
self
.
assertNumQueries
(
1
7
):
with
self
.
assertNumQueries
(
1
8
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
course_run
=
self
.
serialize_catalog_flat_course_run
(
self
.
course_run
)
course_run
=
self
.
serialize_catalog_flat_course_run
(
self
.
course_run
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
View file @
e251012e
...
@@ -4,19 +4,16 @@ import urllib
...
@@ -4,19 +4,16 @@ import urllib
import
ddt
import
ddt
import
pytz
import
pytz
from
django.conf
import
settings
from
django.db.models.functions
import
Lower
from
django.db.models.functions
import
Lower
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APIRequestFactory
,
APITestCase
from
rest_framework.test
import
APIRequestFactory
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.tests.factories
import
(
from
course_discovery.apps.course_metadata.tests.factories
import
CourseRunFactory
,
ProgramFactory
,
SeatFactory
CourseRunFactory
,
PartnerFactory
,
ProgramFactory
,
SeatFactory
)
@ddt.ddt
@ddt.ddt
...
@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -25,10 +22,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super
(
CourseRunViewSetTests
,
self
)
.
setUp
()
super
(
CourseRunViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
client
.
force_authenticate
(
self
.
user
)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self
.
partner
=
PartnerFactory
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
self
.
course_run
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run_2
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run_2
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
refresh_index
()
self
.
refresh_index
()
...
@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -170,15 +163,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key
=
lambda
course_run
:
course_run
[
'key'
])
key
=
lambda
course_run
:
course_run
[
'key'
])
self
.
assertListEqual
(
actual_sorted
,
expected_sorted
)
self
.
assertListEqual
(
actual_sorted
,
expected_sorted
)
def
test_list_query_invalid_partner
(
self
):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query
=
'title:Some random title'
url
=
'{root}?q={query}&partner={partner}'
.
format
(
root
=
reverse
(
'api:v1:course_run-list'
),
query
=
query
,
partner
=
'foo'
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
400
)
def
assert_list_results
(
self
,
url
,
expected
,
extra_context
=
None
):
def
assert_list_results
(
self
,
url
,
expected
,
extra_context
=
None
):
expected
=
sorted
(
expected
,
key
=
lambda
course_run
:
course_run
.
key
.
lower
())
expected
=
sorted
(
expected
,
key
=
lambda
course_run
:
course_run
.
key
.
lower
())
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
@@ -256,7 +240,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -256,7 +240,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
'course_run_ids'
:
self
.
course_run
.
key
,
'course_run_ids'
:
self
.
course_run
.
key
,
})
})
url
=
'{}?{}'
.
format
(
reverse
(
'api:v1:course_run-contains'
),
qs
)
url
=
'{}?{}'
.
format
(
reverse
(
'api:v1:course_run-contains'
),
qs
)
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
self
.
assertEqual
(
self
.
assertEqual
(
...
@@ -268,18 +251,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -268,18 +251,6 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
}
}
)
)
def
test_contains_single_course_run_invalid_partner
(
self
):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs
=
urllib
.
parse
.
urlencode
({
'query'
:
'id:course*'
,
'course_run_ids'
:
self
.
course_run
.
key
,
'partner'
:
'foo'
})
url
=
'{}?{}'
.
format
(
reverse
(
'api:v1:course_run-contains'
),
qs
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
400
def
test_contains_multiple_course_runs
(
self
):
def
test_contains_multiple_course_runs
(
self
):
qs
=
urllib
.
parse
.
urlencode
({
qs
=
urllib
.
parse
.
urlencode
({
'query'
:
'id:course*'
,
'query'
:
'id:course*'
,
...
...
course_discovery/apps/api/v1/tests/test_views/test_courses.py
View file @
e251012e
...
@@ -4,9 +4,8 @@ import ddt
...
@@ -4,9 +4,8 @@ import ddt
import
pytz
import
pytz
from
django.db.models.functions
import
Lower
from
django.db.models.functions
import
Lower
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
from
course_discovery.apps.course_metadata.models
import
Course
...
@@ -22,14 +21,15 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -22,14 +21,15 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseViewSetTests
,
self
)
.
setUp
()
super
(
CourseViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
request
.
user
=
self
.
user
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
course
=
CourseFactory
()
self
.
course
=
CourseFactory
(
partner
=
self
.
partner
)
def
test_get
(
self
):
def
test_get
(
self
):
""" Verify the endpoint returns the details for a single course. """
""" Verify the endpoint returns the details for a single course. """
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
with
self
.
assertNumQueries
(
18
):
with
self
.
assertNumQueries
(
20
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
data
,
self
.
serialize_course
(
self
.
course
))
self
.
assertEqual
(
response
.
data
,
self
.
serialize_course
(
self
.
course
))
...
@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -38,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
with
self
.
assertNumQueries
(
1
1
):
with
self
.
assertNumQueries
(
1
3
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
data
.
get
(
'programs'
),
[])
self
.
assertEqual
(
response
.
data
.
get
(
'programs'
),
[])
...
@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -51,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
+=
'?include_deleted_programs=1'
url
+=
'?include_deleted_programs=1'
with
self
.
assertNumQueries
(
2
2
):
with
self
.
assertNumQueries
(
2
4
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
self
.
assertEqual
(
...
@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -187,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
""" Verify the endpoint returns a list of all courses. """
url
=
reverse
(
'api:v1:course-list'
)
url
=
reverse
(
'api:v1:course-list'
)
with
self
.
assertNumQueries
(
2
4
):
with
self
.
assertNumQueries
(
2
6
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertListEqual
(
self
.
assertListEqual
(
...
@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -203,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query
=
'title:'
+
title
query
=
'title:'
+
title
url
=
'{root}?q={query}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
query
=
query
)
url
=
'{root}?q={query}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
query
=
query
)
with
self
.
assertNumQueries
(
3
7
):
with
self
.
assertNumQueries
(
3
9
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
def
test_list_key_filter
(
self
):
def
test_list_key_filter
(
self
):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses
=
CourseFactory
.
create_batch
(
3
)
courses
=
CourseFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
courses
=
sorted
(
courses
,
key
=
lambda
course
:
course
.
key
.
lower
())
courses
=
sorted
(
courses
,
key
=
lambda
course
:
course
.
key
.
lower
())
keys
=
','
.
join
([
course
.
key
for
course
in
courses
])
keys
=
','
.
join
([
course
.
key
for
course
in
courses
])
url
=
'{root}?keys={keys}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
keys
=
keys
)
url
=
'{root}?keys={keys}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
keys
=
keys
)
with
self
.
assertNumQueries
(
3
7
):
with
self
.
assertNumQueries
(
3
9
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
...
...
course_discovery/apps/api/v1/tests/test_views/test_organizations.py
View file @
e251012e
import
uuid
import
uuid
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.tests.factories
import
Organization
,
OrganizationFactory
from
course_discovery.apps.course_metadata.tests.factories
import
Organization
,
OrganizationFactory
...
@@ -14,6 +13,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -14,6 +13,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
OrganizationViewSetTests
,
self
)
.
setUp
()
super
(
OrganizationViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
request
.
user
=
self
.
user
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
def
test_authentication
(
self
):
def
test_authentication
(
self
):
...
@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -27,17 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
assert_response_data_valid
(
self
,
response
,
organizations
,
many
=
True
):
def
assert_response_data_valid
(
self
,
response
,
organizations
,
many
=
True
):
""" Asserts the response data (only) contains the expected organizations. """
""" Asserts the response data (only) contains the expected organizations. """
actual
=
response
.
data
actual
=
response
.
data
serializer_data
=
self
.
serialize_organization
(
organizations
,
many
=
many
)
if
many
:
if
many
:
actual
=
actual
[
'results'
]
actual
=
actual
[
'results'
]
self
.
assert
Equal
(
actual
,
self
.
serialize_organization
(
organizations
,
many
=
many
)
)
self
.
assert
CountEqual
(
actual
,
serializer_data
)
def
assert_list_uuid_filter
(
self
,
organizations
):
def
assert_list_uuid_filter
(
self
,
organizations
,
expected_query_count
):
""" Asserts the list endpoint supports filtering by UUID. """
""" Asserts the list endpoint supports filtering by UUID. """
organizations
=
sorted
(
organizations
,
key
=
lambda
o
:
o
.
created
)
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
expected_query_count
):
uuids
=
','
.
join
([
organization
.
uuid
.
hex
for
organization
in
organizations
])
uuids
=
','
.
join
([
organization
.
uuid
.
hex
for
organization
in
organizations
])
url
=
'{root}?uuids={uuids}'
.
format
(
root
=
self
.
list_path
,
uuids
=
uuids
)
url
=
'{root}?uuids={uuids}'
.
format
(
root
=
self
.
list_path
,
uuids
=
uuids
)
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
@@ -45,9 +45,8 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -45,9 +45,8 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assert_response_data_valid
(
response
,
organizations
)
self
.
assert_response_data_valid
(
response
,
organizations
)
def
assert_list_tag_filter
(
self
,
organizations
,
tags
,
expected_query_count
=
5
):
def
assert_list_tag_filter
(
self
,
organizations
,
tags
,
expected_query_count
=
7
):
""" Asserts the list endpoint supports filtering by tags. """
""" Asserts the list endpoint supports filtering by tags. """
with
self
.
assertNumQueries
(
expected_query_count
):
with
self
.
assertNumQueries
(
expected_query_count
):
tags
=
','
.
join
(
tags
)
tags
=
','
.
join
(
tags
)
url
=
'{root}?tags={tags}'
.
format
(
root
=
self
.
list_path
,
tags
=
tags
)
url
=
'{root}?tags={tags}'
.
format
(
root
=
self
.
list_path
,
tags
=
tags
)
...
@@ -58,10 +57,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -58,10 +57,9 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_list
(
self
):
def
test_list
(
self
):
""" Verify the endpoint returns a list of all organizations. """
""" Verify the endpoint returns a list of all organizations. """
OrganizationFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
OrganizationFactory
.
create_batch
(
3
)
with
self
.
assertNumQueries
(
7
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
self
.
list_path
)
response
=
self
.
client
.
get
(
self
.
list_path
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -70,22 +68,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -70,22 +68,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_list_uuid_filter
(
self
):
def
test_list_uuid_filter
(
self
):
""" Verify the endpoint returns a list of organizations filtered by UUID. """
""" Verify the endpoint returns a list of organizations filtered by UUID. """
organizations
=
OrganizationFactory
.
create_batch
(
3
)
organizations
=
OrganizationFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
# Test with a single UUID
# Test with a single UUID
self
.
assert_list_uuid_filter
([
organizations
[
0
]])
self
.
assert_list_uuid_filter
([
organizations
[
0
]]
,
7
)
# Test with multiple UUIDs
# Test with multiple UUIDs
self
.
assert_list_uuid_filter
(
organizations
)
self
.
assert_list_uuid_filter
(
organizations
,
7
)
def
test_list_tag_filter
(
self
):
def
test_list_tag_filter
(
self
):
""" Verify the endpoint returns a list of organizations filtered by tag. """
""" Verify the endpoint returns a list of organizations filtered by tag. """
tag
=
'test-org'
tag
=
'test-org'
organizations
=
OrganizationFactory
.
create_batch
(
2
)
organizations
=
OrganizationFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
# If no organizations have been tagged, the endpoint should not return any data
# If no organizations have been tagged, the endpoint should not return any data
self
.
assert_list_tag_filter
([],
[
tag
],
expected_query_count
=
3
)
self
.
assert_list_tag_filter
([],
[
tag
],
expected_query_count
=
5
)
# Tagged organizations should be returned
# Tagged organizations should be returned
organizations
[
0
]
.
tags
.
add
(
tag
)
organizations
[
0
]
.
tags
.
add
(
tag
)
...
@@ -99,7 +97,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -99,7 +97,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_retrieve
(
self
):
def
test_retrieve
(
self
):
""" Verify the endpoint returns details for a single organization. """
""" Verify the endpoint returns details for a single organization. """
organization
=
OrganizationFactory
()
organization
=
OrganizationFactory
(
partner
=
self
.
partner
)
url
=
reverse
(
'api:v1:organization-detail'
,
kwargs
=
{
'uuid'
:
organization
.
uuid
})
url
=
reverse
(
'api:v1:organization-detail'
,
kwargs
=
{
'uuid'
:
organization
.
uuid
})
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_people.py
View file @
e251012e
# pylint: disable=redefined-builtin,no-member
# pylint: disable=redefined-builtin,no-member
import
ddt
import
ddt
from
django.conf
import
settings
from
django.contrib.auth
import
get_user_model
from
django.contrib.auth
import
get_user_model
from
django.db
import
IntegrityError
from
django.db
import
IntegrityError
from
mock
import
mock
from
mock
import
mock
...
@@ -8,20 +7,20 @@ from rest_framework.reverse import reverse
...
@@ -8,20 +7,20 @@ from rest_framework.reverse import reverse
from
rest_framework.test
import
APITestCase
from
rest_framework.test
import
APITestCase
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.views.people
import
logger
as
people_logger
from
course_discovery.apps.api.v1.views.people
import
logger
as
people_logger
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.course_metadata.models
import
Person
from
course_discovery.apps.course_metadata.models
import
Person
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests.factories
import
(
OrganizationFactory
,
PartnerFactory
,
PersonFactory
,
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
,
PersonFactory
,
PositionFactory
PositionFactory
)
User
=
get_user_model
()
User
=
get_user_model
()
@ddt.ddt
@ddt.ddt
class
PersonViewSetTests
(
SerializationMixin
,
APITestCase
):
class
PersonViewSetTests
(
SerializationMixin
,
SiteMixin
,
APITestCase
):
""" Tests for the person resource. """
""" Tests for the person resource. """
people_list_url
=
reverse
(
'api:v1:person-list'
)
people_list_url
=
reverse
(
'api:v1:person-list'
)
...
@@ -29,13 +28,9 @@ class PersonViewSetTests(SerializationMixin, APITestCase):
...
@@ -29,13 +28,9 @@ class PersonViewSetTests(SerializationMixin, APITestCase):
super
(
PersonViewSetTests
,
self
)
.
setUp
()
super
(
PersonViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
person
=
PersonFactory
()
self
.
person
=
PersonFactory
(
partner
=
self
.
partner
)
PositionFactory
(
person
=
self
.
person
)
self
.
organization
=
OrganizationFactory
(
partner
=
self
.
partner
)
self
.
organization
=
OrganizationFactory
()
PositionFactory
(
person
=
self
.
person
,
organization
=
self
.
organization
)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self
.
partner
=
PartnerFactory
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
toggle_switch
(
'publish_person_to_marketing_site'
,
True
)
toggle_switch
(
'publish_person_to_marketing_site'
,
True
)
self
.
expected_node
=
{
self
.
expected_node
=
{
'resource'
:
'node'
,
''
'resource'
:
'node'
,
''
...
...
course_discovery/apps/api/v1/tests/test_views/test_program_types.py
View file @
e251012e
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.models
import
ProgramType
from
course_discovery.apps.course_metadata.models
import
ProgramType
from
course_discovery.apps.course_metadata.tests.factories
import
ProgramTypeFactory
from
course_discovery.apps.course_metadata.tests.factories
import
ProgramTypeFactory
...
@@ -28,7 +27,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
...
@@ -28,7 +27,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """
""" Verify the endpoint returns a list of all program types. """
ProgramTypeFactory
.
create_batch
(
4
)
ProgramTypeFactory
.
create_batch
(
4
)
expected
=
ProgramType
.
objects
.
all
()
expected
=
ProgramType
.
objects
.
all
()
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
self
.
list_path
)
response
=
self
.
client
.
get
(
self
.
list_path
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
@@ -39,7 +38,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
...
@@ -39,7 +38,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type
=
ProgramTypeFactory
()
program_type
=
ProgramTypeFactory
()
url
=
reverse
(
'api:v1:program_type-detail'
,
kwargs
=
{
'slug'
:
program_type
.
slug
})
url
=
reverse
(
'api:v1:program_type-detail'
,
kwargs
=
{
'slug'
:
program_type
.
slug
})
with
self
.
assertNumQueries
(
4
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
...
course_discovery/apps/api/v1/tests/test_views/test_programs.py
View file @
e251012e
...
@@ -3,10 +3,9 @@ import urllib.parse
...
@@ -3,10 +3,9 @@ import urllib.parse
import
ddt
import
ddt
from
django.core.cache
import
cache
from
django.core.cache
import
cache
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.serializers
import
MinimalProgramSerializer
from
course_discovery.apps.api.serializers
import
MinimalProgramSerializer
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.api.v1.views.programs
import
ProgramViewSet
from
course_discovery.apps.api.v1.views.programs
import
ProgramViewSet
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
...
@@ -25,16 +24,17 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -25,16 +24,17 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
ProgramViewSetTests
,
self
)
.
setUp
()
super
(
ProgramViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
request
.
user
=
self
.
user
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
# Clear the cache between test cases, so they don't interfere with each other.
# Clear the cache between test cases, so they don't interfere with each other.
cache
.
clear
()
cache
.
clear
()
def
create_program
(
self
):
def
create_program
(
self
):
organizations
=
[
OrganizationFactory
()]
organizations
=
[
OrganizationFactory
(
partner
=
self
.
partner
)]
person
=
PersonFactory
()
person
=
PersonFactory
()
course
=
CourseFactory
()
course
=
CourseFactory
(
partner
=
self
.
partner
)
CourseRunFactory
(
course
=
course
,
staff
=
[
person
])
CourseRunFactory
(
course
=
course
,
staff
=
[
person
])
program
=
ProgramFactory
(
program
=
ProgramFactory
(
...
@@ -46,7 +46,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -46,7 +46,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
expected_learning_items
=
ExpectedLearningItemFactory
.
create_batch
(
1
),
expected_learning_items
=
ExpectedLearningItemFactory
.
create_batch
(
1
),
job_outlook_items
=
JobOutlookItemFactory
.
create_batch
(
1
),
job_outlook_items
=
JobOutlookItemFactory
.
create_batch
(
1
),
banner_image
=
make_image_file
(
'test_banner.jpg'
),
banner_image
=
make_image_file
(
'test_banner.jpg'
),
video
=
VideoFactory
()
video
=
VideoFactory
(),
partner
=
self
.
partner
)
)
return
program
return
program
...
@@ -73,14 +74,14 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -73,14 +74,14 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_retrieve
(
self
):
def
test_retrieve
(
self
):
""" Verify the endpoint returns the details for a single program. """
""" Verify the endpoint returns the details for a single program. """
program
=
self
.
create_program
()
program
=
self
.
create_program
()
with
self
.
assertNumQueries
(
3
7
):
with
self
.
assertNumQueries
(
3
9
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
# property does not have the right values while being indexed
# property does not have the right values while being indexed
del
program
.
_course_run_weeks_to_complete
del
program
.
_course_run_weeks_to_complete
assert
response
.
data
==
self
.
serialize_program
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
# Verify that repeated retrieve requests use the cache.
# Verify that repeated retrieve requests use the cache.
with
self
.
assertNumQueries
(
2
):
with
self
.
assertNumQueries
(
4
):
self
.
assert_retrieve_success
(
program
)
self
.
assert_retrieve_success
(
program
)
# Verify that requests including querystring parameters are cached separately.
# Verify that requests including querystring parameters are cached separately.
...
@@ -90,22 +91,25 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -90,22 +91,25 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
@ddt.data
(
True
,
False
)
@ddt.data
(
True
,
False
)
def
test_retrieve_with_sorting_flag
(
self
,
order_courses_by_start_date
):
def
test_retrieve_with_sorting_flag
(
self
,
order_courses_by_start_date
):
""" Verify the number of queries is the same with sorting flag set to true. """
""" Verify the number of queries is the same with sorting flag set to true. """
course_list
=
CourseFactory
.
create_batch
(
3
)
course_list
=
CourseFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
for
course
in
course_list
:
for
course
in
course_list
:
CourseRunFactory
(
course
=
course
)
CourseRunFactory
(
course
=
course
)
program
=
ProgramFactory
(
courses
=
course_list
,
order_courses_by_start_date
=
order_courses_by_start_date
)
program
=
ProgramFactory
(
courses
=
course_list
,
order_courses_by_start_date
=
order_courses_by_start_date
,
partner
=
self
.
partner
)
# property does not have the right values while being indexed
# property does not have the right values while being indexed
del
program
.
_course_run_weeks_to_complete
del
program
.
_course_run_weeks_to_complete
with
self
.
assertNumQueries
(
2
6
):
with
self
.
assertNumQueries
(
2
8
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
self
.
assertEqual
(
course_list
,
list
(
program
.
courses
.
all
()))
# pylint: disable=no-member
self
.
assertEqual
(
course_list
,
list
(
program
.
courses
.
all
()))
# pylint: disable=no-member
def
test_retrieve_without_course_runs
(
self
):
def
test_retrieve_without_course_runs
(
self
):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course
=
CourseFactory
()
course
=
CourseFactory
(
partner
=
self
.
partner
)
program
=
ProgramFactory
(
courses
=
[
course
])
program
=
ProgramFactory
(
courses
=
[
course
]
,
partner
=
self
.
partner
)
with
self
.
assertNumQueries
(
2
0
):
with
self
.
assertNumQueries
(
2
2
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
...
@@ -135,18 +139,18 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -135,18 +139,18 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all programs. """
""" Verify the endpoint returns a list of all programs. """
expected
=
[
self
.
create_program
()
for
__
in
range
(
3
)]
expected
=
[
self
.
create_program
()
for
__
in
range
(
3
)]
expected
.
reverse
()
expected
.
reverse
()
self
.
assert_list_results
(
self
.
list_path
,
expected
,
1
2
)
self
.
assert_list_results
(
self
.
list_path
,
expected
,
1
4
)
# Verify that repeated list requests use the cache.
# Verify that repeated list requests use the cache.
self
.
assert_list_results
(
self
.
list_path
,
expected
,
2
)
self
.
assert_list_results
(
self
.
list_path
,
expected
,
4
)
def
test_uuids_only
(
self
):
def
test_uuids_only
(
self
):
"""
"""
Verify that the list view returns a simply list of UUIDs when the
Verify that the list view returns a simply list of UUIDs when the
uuids_only query parameter is passed.
uuids_only query parameter is passed.
"""
"""
active
=
ProgramFactory
.
create_batch
(
3
)
active
=
ProgramFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
retired
=
[
ProgramFactory
(
status
=
ProgramStatus
.
Retired
)]
retired
=
[
ProgramFactory
(
status
=
ProgramStatus
.
Retired
,
partner
=
self
.
partner
)]
programs
=
active
+
retired
programs
=
active
+
retired
querystring
=
{
'uuids_only'
:
1
}
querystring
=
{
'uuids_only'
:
1
}
...
@@ -165,47 +169,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -165,47 +169,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_filter_by_type
(
self
):
def
test_filter_by_type
(
self
):
""" Verify that the endpoint filters programs to those of a given type. """
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name
=
'foo'
program_type_name
=
'foo'
program
=
ProgramFactory
(
type__name
=
program_type_name
)
program
=
ProgramFactory
(
type__name
=
program_type_name
,
partner
=
self
.
partner
)
url
=
self
.
list_path
+
'?type='
+
program_type_name
url
=
self
.
list_path
+
'?type='
+
program_type_name
self
.
assert_list_results
(
url
,
[
program
],
8
)
self
.
assert_list_results
(
url
,
[
program
],
10
)
url
=
self
.
list_path
+
'?type=bar'
url
=
self
.
list_path
+
'?type=bar'
self
.
assert_list_results
(
url
,
[],
3
)
self
.
assert_list_results
(
url
,
[],
5
)
def
test_filter_by_types
(
self
):
def
test_filter_by_types
(
self
):
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
expected
=
ProgramFactory
.
create_batch
(
2
)
expected
=
ProgramFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
expected
.
reverse
()
expected
.
reverse
()
type_slugs
=
[
p
.
type
.
slug
for
p
in
expected
]
type_slugs
=
[
p
.
type
.
slug
for
p
in
expected
]
url
=
self
.
list_path
+
'?types='
+
','
.
join
(
type_slugs
)
url
=
self
.
list_path
+
'?types='
+
','
.
join
(
type_slugs
)
# Create a third program, which should be filtered out.
# Create a third program, which should be filtered out.
ProgramFactory
()
ProgramFactory
(
partner
=
self
.
partner
)
self
.
assert_list_results
(
url
,
expected
,
8
)
self
.
assert_list_results
(
url
,
expected
,
10
)
def
test_filter_by_uuids
(
self
):
def
test_filter_by_uuids
(
self
):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
expected
=
ProgramFactory
.
create_batch
(
2
)
expected
=
ProgramFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
expected
.
reverse
()
expected
.
reverse
()
uuids
=
[
str
(
p
.
uuid
)
for
p
in
expected
]
uuids
=
[
str
(
p
.
uuid
)
for
p
in
expected
]
url
=
self
.
list_path
+
'?uuids='
+
','
.
join
(
uuids
)
url
=
self
.
list_path
+
'?uuids='
+
','
.
join
(
uuids
)
# Create a third program, which should be filtered out.
# Create a third program, which should be filtered out.
ProgramFactory
()
ProgramFactory
(
partner
=
self
.
partner
)
self
.
assert_list_results
(
url
,
expected
,
8
)
self
.
assert_list_results
(
url
,
expected
,
10
)
@ddt.data
(
@ddt.data
(
(
ProgramStatus
.
Unpublished
,
False
,
3
),
(
ProgramStatus
.
Unpublished
,
False
,
5
),
(
ProgramStatus
.
Active
,
True
,
8
),
(
ProgramStatus
.
Active
,
True
,
10
),
)
)
@ddt.unpack
@ddt.unpack
def
test_filter_by_marketable
(
self
,
status
,
is_marketable
,
expected_query_count
):
def
test_filter_by_marketable
(
self
,
status
,
is_marketable
,
expected_query_count
):
""" Verify the endpoint filters programs to those that are marketable. """
""" Verify the endpoint filters programs to those that are marketable. """
url
=
self
.
list_path
+
'?marketable=1'
url
=
self
.
list_path
+
'?marketable=1'
ProgramFactory
(
marketing_slug
=
''
)
ProgramFactory
(
marketing_slug
=
''
,
partner
=
self
.
partner
)
programs
=
ProgramFactory
.
create_batch
(
3
,
status
=
status
)
programs
=
ProgramFactory
.
create_batch
(
3
,
status
=
status
,
partner
=
self
.
partner
)
programs
.
reverse
()
programs
.
reverse
()
expected
=
programs
if
is_marketable
else
[]
expected
=
programs
if
is_marketable
else
[]
...
@@ -214,40 +218,40 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -214,40 +218,40 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_filter_by_status
(
self
):
def
test_filter_by_status
(
self
):
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
active
=
ProgramFactory
(
status
=
ProgramStatus
.
Active
)
active
=
ProgramFactory
(
status
=
ProgramStatus
.
Active
,
partner
=
self
.
partner
)
retired
=
ProgramFactory
(
status
=
ProgramStatus
.
Retired
)
retired
=
ProgramFactory
(
status
=
ProgramStatus
.
Retired
,
partner
=
self
.
partner
)
url
=
self
.
list_path
+
'?status=active'
url
=
self
.
list_path
+
'?status=active'
self
.
assert_list_results
(
url
,
[
active
],
8
)
self
.
assert_list_results
(
url
,
[
active
],
10
)
url
=
self
.
list_path
+
'?status=retired'
url
=
self
.
list_path
+
'?status=retired'
self
.
assert_list_results
(
url
,
[
retired
],
8
)
self
.
assert_list_results
(
url
,
[
retired
],
10
)
url
=
self
.
list_path
+
'?status=active&status=retired'
url
=
self
.
list_path
+
'?status=active&status=retired'
self
.
assert_list_results
(
url
,
[
retired
,
active
],
8
)
self
.
assert_list_results
(
url
,
[
retired
,
active
],
10
)
def
test_filter_by_hidden
(
self
):
def
test_filter_by_hidden
(
self
):
""" Endpoint should filter programs by their hidden attribute value. """
""" Endpoint should filter programs by their hidden attribute value. """
hidden
=
ProgramFactory
(
hidden
=
True
)
hidden
=
ProgramFactory
(
hidden
=
True
,
partner
=
self
.
partner
)
not_hidden
=
ProgramFactory
(
hidden
=
False
)
not_hidden
=
ProgramFactory
(
hidden
=
False
,
partner
=
self
.
partner
)
url
=
self
.
list_path
+
'?hidden=True'
url
=
self
.
list_path
+
'?hidden=True'
self
.
assert_list_results
(
url
,
[
hidden
],
8
)
self
.
assert_list_results
(
url
,
[
hidden
],
10
)
url
=
self
.
list_path
+
'?hidden=False'
url
=
self
.
list_path
+
'?hidden=False'
self
.
assert_list_results
(
url
,
[
not_hidden
],
8
)
self
.
assert_list_results
(
url
,
[
not_hidden
],
10
)
url
=
self
.
list_path
+
'?hidden=1'
url
=
self
.
list_path
+
'?hidden=1'
self
.
assert_list_results
(
url
,
[
hidden
],
8
)
self
.
assert_list_results
(
url
,
[
hidden
],
10
)
url
=
self
.
list_path
+
'?hidden=0'
url
=
self
.
list_path
+
'?hidden=0'
self
.
assert_list_results
(
url
,
[
not_hidden
],
8
)
self
.
assert_list_results
(
url
,
[
not_hidden
],
10
)
def
test_list_exclude_utm
(
self
):
def
test_list_exclude_utm
(
self
):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
""" Verify the endpoint returns marketing URLs without UTM parameters. """
url
=
self
.
list_path
+
'?exclude_utm=1'
url
=
self
.
list_path
+
'?exclude_utm=1'
program
=
self
.
create_program
()
program
=
self
.
create_program
()
self
.
assert_list_results
(
url
,
[
program
],
1
2
,
extra_context
=
{
'exclude_utm'
:
1
})
self
.
assert_list_results
(
url
,
[
program
],
1
4
,
extra_context
=
{
'exclude_utm'
:
1
})
def
test_minimal_serializer_use
(
self
):
def
test_minimal_serializer_use
(
self
):
""" Verify that the list view uses the minimal serializer. """
""" Verify that the list view uses the minimal serializer. """
...
...
course_discovery/apps/api/v1/tests/test_views/test_search.py
View file @
e251012e
...
@@ -3,13 +3,12 @@ import json
...
@@ -3,13 +3,12 @@ import json
import
urllib.parse
import
urllib.parse
import
ddt
import
ddt
from
django.conf
import
settings
from
django.urls
import
reverse
from
django.urls
import
reverse
from
haystack.query
import
SearchQuerySet
from
haystack.query
import
SearchQuerySet
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.serializers
import
(
CourseRunSearchSerializer
,
ProgramSearchSerializer
,
from
course_discovery.apps.api.serializers
import
(
CourseRunSearchSerializer
,
ProgramSearchSerializer
,
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
)
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
)
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
from
course_discovery.apps.api.v1.views.search
import
TypeaheadSearchView
from
course_discovery.apps.api.v1.views.search
import
TypeaheadSearchView
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
...
@@ -88,14 +87,8 @@ class SynonymTestMixin:
...
@@ -88,14 +87,8 @@ class SynonymTestMixin:
self
.
assertDictEqual
(
response1
,
response2
)
self
.
assertDictEqual
(
response1
,
response2
)
class
DefaultPartnerMixin
:
def
setUp
(
self
):
super
(
DefaultPartnerMixin
,
self
)
.
setUp
()
self
.
partner
=
PartnerFactory
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
@ddt.ddt
@ddt.ddt
class
CourseRunSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
CourseRunSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
APITestCase
):
APITestCase
):
""" Tests for CourseRunSearchViewSet. """
""" Tests for CourseRunSearchViewSet. """
faceted_path
=
reverse
(
'api:v1:search-course_runs-facets'
)
faceted_path
=
reverse
(
'api:v1:search-course_runs-facets'
)
...
@@ -162,7 +155,9 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
...
@@ -162,7 +155,9 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
return
course_run
,
response_data
return
course_run
,
response_data
def
build_facet_url
(
self
,
params
):
def
build_facet_url
(
self
,
params
):
return
'http://testserver{path}?{query}'
.
format
(
path
=
self
.
faceted_path
,
query
=
urllib
.
parse
.
urlencode
(
params
))
return
'http://testserver.fake{path}?{query}'
.
format
(
path
=
self
.
faceted_path
,
query
=
urllib
.
parse
.
urlencode
(
params
)
)
def
test_invalid_query_facet
(
self
):
def
test_invalid_query_facet
(
self
):
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
""" Verify the endpoint returns HTTP 400 if an invalid facet is requested. """
...
@@ -271,7 +266,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
...
@@ -271,7 +266,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
)
)
self
.
reindex_courses
(
program
)
self
.
reindex_courses
(
program
)
with
self
.
assertNumQueries
(
4
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -295,7 +290,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
...
@@ -295,7 +290,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
ProgramFactory
(
courses
=
[
course_run
.
course
],
status
=
program_status
)
ProgramFactory
(
courses
=
[
course_run
.
course
],
status
=
program_status
)
self
.
reindex_courses
(
active_program
)
self
.
reindex_courses
(
active_program
)
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
6
):
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -313,7 +308,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
...
@@ -313,7 +308,7 @@ class CourseRunSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
@ddt.ddt
@ddt.ddt
class
AggregateSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
AggregateSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'api:v1:search-all-facets'
)
path
=
reverse
(
'api:v1:search-all-facets'
)
...
@@ -438,7 +433,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
...
@@ -438,7 +433,7 @@ class AggregateSearchViewSetTests(DefaultPartnerMixin, SerializationMixin, Login
assert
expected
==
actual
assert
expected
==
actual
class
TypeaheadSearchViewTests
(
DefaultPartnerMixin
,
TypeaheadSerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
TypeaheadSearchViewTests
(
TypeaheadSerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'api:v1:search-typeahead'
)
path
=
reverse
(
'api:v1:search-typeahead'
)
...
@@ -620,23 +615,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
...
@@ -620,23 +615,3 @@ class TypeaheadSearchViewTests(DefaultPartnerMixin, TypeaheadSerializationMixin,
self
.
serialize_program
(
harvard_program
)]
self
.
serialize_program
(
harvard_program
)]
}
}
self
.
assertDictEqual
(
response
.
data
,
expected
)
self
.
assertDictEqual
(
response
.
data
,
expected
)
def
test_typeahead_partner_filter
(
self
):
""" Ensure that a partner param limits results to that partner. """
course_runs
=
[]
programs
=
[]
for
partner
in
[
'edx'
,
'other'
]:
title
=
'Belongs to partner '
+
partner
partner
=
PartnerFactory
(
short_code
=
partner
)
course_runs
.
append
(
CourseRunFactory
(
title
=
title
,
course
=
CourseFactory
(
partner
=
partner
)))
programs
.
append
(
ProgramFactory
(
title
=
title
,
partner
=
partner
,
status
=
ProgramStatus
.
Active
))
response
=
self
.
get_response
({
'q'
:
'partner'
},
'edx'
)
self
.
assertEqual
(
response
.
status_code
,
200
)
edx_course_run
=
course_runs
[
0
]
edx_program
=
programs
[
0
]
self
.
assertDictEqual
(
response
.
data
,
{
'course_runs'
:
[
self
.
serialize_course_run
(
edx_course_run
)],
'programs'
:
[
self
.
serialize_program
(
edx_program
)]})
course_discovery/apps/api/v1/views/__init__.py
View file @
e251012e
...
@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset):
...
@@ -35,18 +35,3 @@ def prefetch_related_objects_for_courses(queryset):
queryset
=
queryset
.
select_related
(
*
_select_related_fields
[
'course'
])
queryset
=
queryset
.
select_related
(
*
_select_related_fields
[
'course'
])
queryset
=
queryset
.
prefetch_related
(
*
_prefetch_fields
[
'course'
])
queryset
=
queryset
.
prefetch_related
(
*
_prefetch_fields
[
'course'
])
return
queryset
return
queryset
class
PartnerMixin
:
def
get_partner
(
self
):
""" Return the partner for the short_code passed in or the default partner """
partner_code
=
self
.
request
.
query_params
.
get
(
'partner'
)
if
partner_code
:
try
:
partner
=
Partner
.
objects
.
get
(
short_code
=
partner_code
)
except
Partner
.
DoesNotExist
:
raise
InvalidPartnerError
(
'Unknown Partner: {}'
.
format
(
partner_code
))
else
:
partner
=
Partner
.
objects
.
get
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
return
partner
course_discovery/apps/api/v1/views/catalogs.py
View file @
e251012e
...
@@ -94,6 +94,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
...
@@ -94,6 +94,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
course_runs
=
CourseRun
.
objects
.
active
()
.
enrollable
()
.
marketable
()
course_runs
=
CourseRun
.
objects
.
active
()
.
enrollable
()
.
marketable
()
queryset
=
serializers
.
CatalogCourseSerializer
.
prefetch_queryset
(
queryset
=
serializers
.
CatalogCourseSerializer
.
prefetch_queryset
(
self
.
request
.
site
.
partner
,
queryset
=
queryset
,
queryset
=
queryset
,
course_runs
=
course_runs
course_runs
=
course_runs
)
)
...
...
course_discovery/apps/api/v1/views/course_runs.py
View file @
e251012e
...
@@ -9,14 +9,13 @@ from rest_framework.response import Response
...
@@ -9,14 +9,13 @@ from rest_framework.response import Response
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api.pagination
import
ProxiedPagination
from
course_discovery.apps.api.pagination
import
ProxiedPagination
from
course_discovery.apps.api.utils
import
get_query_param
from
course_discovery.apps.api.utils
import
get_query_param
from
course_discovery.apps.api.v1.views
import
PartnerMixin
from
course_discovery.apps.core.utils
import
SearchQuerySetWrapper
from
course_discovery.apps.core.utils
import
SearchQuerySetWrapper
from
course_discovery.apps.course_metadata.constants
import
COURSE_RUN_ID_REGEX
from
course_discovery.apps.course_metadata.constants
import
COURSE_RUN_ID_REGEX
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.models
import
CourseRun
# pylint: disable=no-member
# pylint: disable=no-member
class
CourseRunViewSet
(
PartnerMixin
,
viewsets
.
ModelViewSet
):
class
CourseRunViewSet
(
viewsets
.
ModelViewSet
):
""" CourseRun resource. """
""" CourseRun resource. """
filter_backends
=
(
DjangoFilterBackend
,
OrderingFilter
)
filter_backends
=
(
DjangoFilterBackend
,
OrderingFilter
)
filter_class
=
filters
.
CourseRunFilter
filter_class
=
filters
.
CourseRunFilter
...
@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
...
@@ -43,7 +42,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
multiple: false
multiple: false
"""
"""
q
=
self
.
request
.
query_params
.
get
(
'q'
)
q
=
self
.
request
.
query_params
.
get
(
'q'
)
partner
=
self
.
get_partner
()
partner
=
self
.
request
.
site
.
partner
if
q
:
if
q
:
qs
=
SearchQuerySetWrapper
(
CourseRun
.
search
(
q
)
.
filter
(
partner
=
partner
.
short_code
))
qs
=
SearchQuerySetWrapper
(
CourseRun
.
search
(
q
)
.
filter
(
partner
=
partner
.
short_code
))
...
@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
...
@@ -80,12 +79,6 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
type: string
type: string
paramType: query
paramType: query
multiple: false
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden
- name: hidden
description: Filter based on wether the course run is hidden from search.
description: Filter based on wether the course run is hidden from search.
required: false
required: false
...
@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
...
@@ -166,7 +159,7 @@ class CourseRunViewSet(PartnerMixin, viewsets.ModelViewSet):
"""
"""
query
=
request
.
GET
.
get
(
'query'
)
query
=
request
.
GET
.
get
(
'query'
)
course_run_ids
=
request
.
GET
.
get
(
'course_run_ids'
)
course_run_ids
=
request
.
GET
.
get
(
'course_run_ids'
)
partner
=
self
.
get_partner
()
partner
=
self
.
request
.
site
.
partner
if
query
and
course_run_ids
:
if
query
and
course_run_ids
:
course_run_ids
=
course_run_ids
.
split
(
','
)
course_run_ids
=
course_run_ids
.
split
(
','
)
...
...
course_discovery/apps/api/v1/views/courses.py
View file @
e251012e
...
@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -18,7 +18,6 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class
=
filters
.
CourseFilter
filter_class
=
filters
.
CourseFilter
lookup_field
=
'key'
lookup_field
=
'key'
lookup_value_regex
=
COURSE_ID_REGEX
lookup_value_regex
=
COURSE_ID_REGEX
queryset
=
Course
.
objects
.
all
()
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
serializer_class
=
serializers
.
CourseWithProgramsSerializer
serializer_class
=
serializers
.
CourseWithProgramsSerializer
...
@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -27,16 +26,17 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class
=
ProxiedPagination
pagination_class
=
ProxiedPagination
def
get_queryset
(
self
):
def
get_queryset
(
self
):
partner
=
self
.
request
.
site
.
partner
q
=
self
.
request
.
query_params
.
get
(
'q'
)
q
=
self
.
request
.
query_params
.
get
(
'q'
)
if
q
:
if
q
:
queryset
=
Course
.
search
(
q
)
queryset
=
Course
.
search
(
q
)
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
queryset
)
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
queryset
,
partner
=
partner
)
else
:
else
:
if
get_query_param
(
self
.
request
,
'include_hidden_course_runs'
):
if
get_query_param
(
self
.
request
,
'include_hidden_course_runs'
):
course_runs
=
CourseRun
.
objects
.
all
(
)
course_runs
=
CourseRun
.
objects
.
filter
(
course__partner
=
partner
)
else
:
else
:
course_runs
=
CourseRun
.
objects
.
exclude
(
hidden
=
True
)
course_runs
=
CourseRun
.
objects
.
filter
(
course__partner
=
partner
)
.
exclude
(
hidden
=
True
)
if
get_query_param
(
self
.
request
,
'marketable_course_runs_only'
):
if
get_query_param
(
self
.
request
,
'marketable_course_runs_only'
):
course_runs
=
course_runs
.
marketable
()
.
active
()
course_runs
=
course_runs
.
marketable
()
.
active
()
...
@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -49,7 +49,8 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
self
.
queryset
,
queryset
=
self
.
queryset
,
course_runs
=
course_runs
course_runs
=
course_runs
,
partner
=
partner
)
)
return
queryset
.
order_by
(
Lower
(
'key'
))
return
queryset
.
order_by
(
Lower
(
'key'
))
...
...
course_discovery/apps/api/v1/views/organizations.py
View file @
e251012e
...
@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -15,13 +15,16 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field
=
'uuid'
lookup_field
=
'uuid'
lookup_value_regex
=
'[0-9a-f-]+'
lookup_value_regex
=
'[0-9a-f-]+'
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
queryset
=
serializers
.
OrganizationSerializer
.
prefetch_queryset
()
serializer_class
=
serializers
.
OrganizationSerializer
serializer_class
=
serializers
.
OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination.
# versions of this API should only support the system default, PageNumberPagination.
pagination_class
=
ProxiedPagination
pagination_class
=
ProxiedPagination
def
get_queryset
(
self
):
partner
=
self
.
request
.
site
.
partner
return
serializers
.
OrganizationSerializer
.
prefetch_queryset
(
partner
=
partner
)
def
list
(
self
,
request
,
*
args
,
**
kwargs
):
def
list
(
self
,
request
,
*
args
,
**
kwargs
):
""" Retrieve a list of all organizations. """
""" Retrieve a list of all organizations. """
return
super
(
OrganizationViewSet
,
self
)
.
list
(
request
,
*
args
,
**
kwargs
)
return
super
(
OrganizationViewSet
,
self
)
.
list
(
request
,
*
args
,
**
kwargs
)
...
...
course_discovery/apps/api/v1/views/people.py
View file @
e251012e
...
@@ -7,7 +7,6 @@ from rest_framework.response import Response
...
@@ -7,7 +7,6 @@ from rest_framework.response import Response
from
course_discovery.apps.api
import
serializers
from
course_discovery.apps.api
import
serializers
from
course_discovery.apps.api.pagination
import
PageNumberPagination
from
course_discovery.apps.api.pagination
import
PageNumberPagination
from
course_discovery.apps.api.v1.views
import
PartnerMixin
from
course_discovery.apps.course_metadata.exceptions
import
MarketingSiteAPIClientException
,
PersonToMarketingException
from
course_discovery.apps.course_metadata.exceptions
import
MarketingSiteAPIClientException
,
PersonToMarketingException
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
...
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
...
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member
# pylint: disable=no-member
class
PersonViewSet
(
PartnerMixin
,
viewsets
.
ModelViewSet
):
class
PersonViewSet
(
viewsets
.
ModelViewSet
):
""" PersonSerializer resource. """
""" PersonSerializer resource. """
lookup_field
=
'uuid'
lookup_field
=
'uuid'
...
@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
...
@@ -30,7 +29,7 @@ class PersonViewSet(PartnerMixin, viewsets.ModelViewSet):
""" Create a new person. """
""" Create a new person. """
person_data
=
request
.
data
person_data
=
request
.
data
partner
=
self
.
get_partner
()
partner
=
request
.
site
.
partner
person_data
[
'partner'
]
=
partner
.
id
person_data
[
'partner'
]
=
partner
.
id
serializer
=
self
.
get_serializer
(
data
=
person_data
)
serializer
=
self
.
get_serializer
(
data
=
person_data
)
serializer
.
is_valid
(
raise_exception
=
True
)
serializer
.
is_valid
(
raise_exception
=
True
)
...
...
course_discovery/apps/api/v1/views/programs.py
View file @
e251012e
...
@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
...
@@ -32,7 +32,8 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def
get_queryset
(
self
):
def
get_queryset
(
self
):
# This method prevents prefetches on the program queryset from "stacking,"
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
# which happens when the queryset is stored in a class property.
return
self
.
get_serializer_class
()
.
prefetch_queryset
()
partner
=
self
.
request
.
site
.
partner
return
self
.
get_serializer_class
()
.
prefetch_queryset
(
partner
)
def
get_serializer_context
(
self
,
*
args
,
**
kwargs
):
def
get_serializer_context
(
self
,
*
args
,
**
kwargs
):
context
=
super
()
.
get_serializer_context
(
*
args
,
**
kwargs
)
context
=
super
()
.
get_serializer_context
(
*
args
,
**
kwargs
)
...
@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
...
@@ -89,7 +90,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if
get_query_param
(
self
.
request
,
'uuids_only'
):
if
get_query_param
(
self
.
request
,
'uuids_only'
):
# DRF serializers don't have good support for simple, flat
# DRF serializers don't have good support for simple, flat
# representations like the one we want here.
# representations like the one we want here.
queryset
=
self
.
filter_queryset
(
Program
.
objects
.
all
(
))
queryset
=
self
.
filter_queryset
(
Program
.
objects
.
filter
(
partner
=
self
.
request
.
site
.
partner
))
uuids
=
queryset
.
values_list
(
'uuid'
,
flat
=
True
)
uuids
=
queryset
.
values_list
(
'uuid'
,
flat
=
True
)
return
Response
(
uuids
)
return
Response
(
uuids
)
...
...
course_discovery/apps/api/v1/views/search.py
View file @
e251012e
...
@@ -12,7 +12,6 @@ from rest_framework.response import Response
...
@@ -12,7 +12,6 @@ from rest_framework.response import Response
from
rest_framework.views
import
APIView
from
rest_framework.views
import
APIView
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api.v1.views
import
PartnerMixin
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Program
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Program
...
@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
...
@@ -119,7 +118,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class
=
serializers
.
AggregateSearchSerializer
serializer_class
=
serializers
.
AggregateSearchSerializer
class
TypeaheadSearchView
(
PartnerMixin
,
APIView
):
class
TypeaheadSearchView
(
APIView
):
""" Typeahead for courses and programs. """
""" Typeahead for courses and programs. """
RESULT_COUNT
=
3
RESULT_COUNT
=
3
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
...
@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView):
...
@@ -181,7 +180,7 @@ class TypeaheadSearchView(PartnerMixin, APIView):
type: string
type: string
"""
"""
query
=
request
.
query_params
.
get
(
'q'
)
query
=
request
.
query_params
.
get
(
'q'
)
partner
=
self
.
get_partner
()
partner
=
request
.
site
.
partner
if
not
query
:
if
not
query
:
raise
ValidationError
(
"The 'q' querystring parameter is required for searching."
)
raise
ValidationError
(
"The 'q' querystring parameter is required for searching."
)
course_runs
,
programs
=
self
.
get_results
(
query
,
partner
)
course_runs
,
programs
=
self
.
get_results
(
query
,
partner
)
...
...
course_discovery/apps/core/tests/test_lookups.py
View file @
e251012e
...
@@ -3,10 +3,11 @@ import json
...
@@ -3,10 +3,11 @@ import json
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
class
UserAutocompleteTests
(
TestCase
):
class
UserAutocompleteTests
(
SiteMixin
,
TestCase
):
""" Tests for user autocomplete lookups."""
""" Tests for user autocomplete lookups."""
def
setUp
(
self
):
def
setUp
(
self
):
...
...
course_discovery/apps/core/tests/test_throttles.py
View file @
e251012e
from
django.conf
import
settings
from
django.core.cache
import
cache
from
django.core.cache
import
cache
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.models
import
UserThrottleRate
from
course_discovery.apps.core.models
import
UserThrottleRate
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.throttles
import
OverridableUserRateThrottle
from
course_discovery.apps.core.throttles
import
OverridableUserRateThrottle
class
RateLimitingTest
(
APITestCase
):
class
RateLimitingTest
(
SiteMixin
,
APITestCase
):
"""
"""
Testing rate limiting of API calls.
Testing rate limiting of API calls.
"""
"""
...
@@ -16,8 +16,6 @@ class RateLimitingTest(APITestCase):
...
@@ -16,8 +16,6 @@ class RateLimitingTest(APITestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
RateLimitingTest
,
self
)
.
setUp
()
super
(
RateLimitingTest
,
self
)
.
setUp
()
PartnerFactory
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
self
.
url
=
reverse
(
'api_docs'
)
self
.
url
=
reverse
(
'api_docs'
)
self
.
user
=
UserFactory
()
self
.
user
=
UserFactory
()
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
...
...
course_discovery/apps/core/tests/test_views.py
View file @
e251012e
...
@@ -9,18 +9,20 @@ from django.test.utils import override_settings
...
@@ -9,18 +9,20 @@ from django.test.utils import override_settings
from
django.urls
import
reverse
from
django.urls
import
reverse
from
django.utils.encoding
import
force_text
from
django.utils.encoding
import
force_text
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.constants
import
Status
from
course_discovery.apps.core.constants
import
Status
User
=
get_user_model
()
User
=
get_user_model
()
class
HealthTests
(
TestCase
):
class
HealthTests
(
SiteMixin
,
TestCase
):
"""Tests of the health endpoint."""
"""Tests of the health endpoint."""
def
test_all_services_available
(
self
):
def
test_all_services_available
(
self
):
"""Test that the endpoint reports when all services are healthy."""
"""Test that the endpoint reports when all services are healthy."""
self
.
_assert_health
(
200
,
Status
.
OK
,
Status
.
OK
)
self
.
_assert_health
(
200
,
Status
.
OK
,
Status
.
OK
)
@mock.patch
(
'django.contrib.sites.middleware.get_current_site'
,
mock
.
Mock
(
return_value
=
None
))
def
test_database_outage
(
self
):
def
test_database_outage
(
self
):
"""Test that the endpoint reports when the database is unavailable."""
"""Test that the endpoint reports when the database is unavailable."""
with
mock
.
patch
(
'django.db.backends.base.base.BaseDatabaseWrapper.cursor'
,
side_effect
=
DatabaseError
):
with
mock
.
patch
(
'django.db.backends.base.base.BaseDatabaseWrapper.cursor'
,
side_effect
=
DatabaseError
):
...
@@ -42,7 +44,7 @@ class HealthTests(TestCase):
...
@@ -42,7 +44,7 @@ class HealthTests(TestCase):
self
.
assertJSONEqual
(
force_text
(
response
.
content
),
expected_data
)
self
.
assertJSONEqual
(
force_text
(
response
.
content
),
expected_data
)
class
AutoAuthTests
(
TestCase
):
class
AutoAuthTests
(
SiteMixin
,
TestCase
):
""" Auto Auth view tests. """
""" Auto Auth view tests. """
AUTO_AUTH_PATH
=
reverse
(
'auto_auth'
)
AUTO_AUTH_PATH
=
reverse
(
'auto_auth'
)
...
...
course_discovery/apps/course_metadata/tests/test_admin.py
View file @
e251012e
...
@@ -11,6 +11,7 @@ from selenium.webdriver.support import expected_conditions as EC
...
@@ -11,6 +11,7 @@ from selenium.webdriver.support import expected_conditions as EC
from
selenium.webdriver.support.ui
import
Select
from
selenium.webdriver.support.ui
import
Select
from
selenium.webdriver.support.wait
import
WebDriverWait
from
selenium.webdriver.support.wait
import
WebDriverWait
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.models
import
Partner
from
course_discovery.apps.core.models
import
Partner
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
...
@@ -23,7 +24,7 @@ from course_discovery.apps.course_metadata.tests import factories
...
@@ -23,7 +24,7 @@ from course_discovery.apps.course_metadata.tests import factories
# pylint: disable=no-member
# pylint: disable=no-member
@ddt.ddt
@ddt.ddt
class
AdminTests
(
TestCase
):
class
AdminTests
(
SiteMixin
,
TestCase
):
""" Tests Admin page."""
""" Tests Admin page."""
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -190,7 +191,7 @@ class AdminTests(TestCase):
...
@@ -190,7 +191,7 @@ class AdminTests(TestCase):
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
class
ProgramAdminFunctionalTests
(
LiveServerTestCase
):
class
ProgramAdminFunctionalTests
(
SiteMixin
,
LiveServerTestCase
):
""" Functional Tests for Admin page."""
""" Functional Tests for Admin page."""
# Required for access to initial data loaded in migrations (e.g., LanguageTags).
# Required for access to initial data loaded in migrations (e.g., LanguageTags).
serialized_rollback
=
True
serialized_rollback
=
True
...
@@ -224,7 +225,6 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
...
@@ -224,7 +225,6 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
()
.
setUp
()
super
()
.
setUp
()
# ContentTypeManager uses a cache to speed up ContentType retrieval. This
# ContentTypeManager uses a cache to speed up ContentType retrieval. This
# cache persists across tests. This is fine in the context of a regular
# cache persists across tests. This is fine in the context of a regular
# TestCase which uses a transaction to reset the database between tests.
# TestCase which uses a transaction to reset the database between tests.
...
@@ -238,6 +238,9 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
...
@@ -238,6 +238,9 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
# stale ContentType objects from being used.
# stale ContentType objects from being used.
ContentType
.
objects
.
clear_cache
()
ContentType
.
objects
.
clear_cache
()
self
.
site
.
domain
=
self
.
live_server_url
.
strip
(
'http://'
)
self
.
site
.
save
()
self
.
course_runs
=
factories
.
CourseRunFactory
.
create_batch
(
2
)
self
.
course_runs
=
factories
.
CourseRunFactory
.
create_batch
(
2
)
self
.
courses
=
[
course_run
.
course
for
course_run
in
self
.
course_runs
]
self
.
courses
=
[
course_run
.
course
for
course_run
in
self
.
course_runs
]
...
@@ -349,7 +352,7 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
...
@@ -349,7 +352,7 @@ class ProgramAdminFunctionalTests(LiveServerTestCase):
self
.
assertEqual
(
self
.
program
.
subtitle
,
subtitle
)
self
.
assertEqual
(
self
.
program
.
subtitle
,
subtitle
)
class
ProgramEligibilityFilterTests
(
TestCase
):
class
ProgramEligibilityFilterTests
(
SiteMixin
,
TestCase
):
""" Tests for Program Eligibility Filter class. """
""" Tests for Program Eligibility Filter class. """
parameter_name
=
'eligible_for_one_click_purchase'
parameter_name
=
'eligible_for_one_click_purchase'
...
...
course_discovery/apps/course_metadata/tests/test_lookups.py
View file @
e251012e
...
@@ -5,6 +5,7 @@ import ddt
...
@@ -5,6 +5,7 @@ import ddt
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.tests.factories
import
(
from
course_discovery.apps.course_metadata.tests.factories
import
(
CourseFactory
,
CourseRunFactory
,
OrganizationFactory
,
PersonFactory
,
PositionFactory
CourseFactory
,
CourseRunFactory
,
OrganizationFactory
,
PersonFactory
,
PositionFactory
...
@@ -16,7 +17,7 @@ from course_discovery.apps.publisher.tests import factories
...
@@ -16,7 +17,7 @@ from course_discovery.apps.publisher.tests import factories
@ddt.ddt
@ddt.ddt
class
AutocompleteTests
(
TestCase
):
class
AutocompleteTests
(
SiteMixin
,
TestCase
):
""" Tests for autocomplete lookups."""
""" Tests for autocomplete lookups."""
def
setUp
(
self
):
def
setUp
(
self
):
super
(
AutocompleteTests
,
self
)
.
setUp
()
super
(
AutocompleteTests
,
self
)
.
setUp
()
...
@@ -118,7 +119,7 @@ class AutocompleteTests(TestCase):
...
@@ -118,7 +119,7 @@ class AutocompleteTests(TestCase):
@ddt.ddt
@ddt.ddt
class
AutoCompletePersonTests
(
TestCase
):
class
AutoCompletePersonTests
(
SiteMixin
,
TestCase
):
"""
"""
Tests for person autocomplete lookups
Tests for person autocomplete lookups
"""
"""
...
...
course_discovery/apps/edx_catalog_extensions/api/v1/tests/test_views.py
View file @
e251012e
...
@@ -2,17 +2,17 @@ import datetime
...
@@ -2,17 +2,17 @@ import datetime
import
urllib.parse
import
urllib.parse
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.test_search
import
(
from
course_discovery.apps.api.v1.tests.test_views.test_search
import
(
DefaultPartnerMixin
,
ElasticsearchTestMixin
,
LoginMixin
,
SerializationMixin
,
SynonymTestMixin
ElasticsearchTestMixin
,
LoginMixin
,
SerializationMixin
,
SynonymTestMixin
)
)
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.tests.factories
import
CourseFactory
,
CourseRunFactory
,
ProgramFactory
from
course_discovery.apps.course_metadata.tests.factories
import
CourseFactory
,
CourseRunFactory
,
ProgramFactory
from
course_discovery.apps.edx_catalog_extensions.api.serializers
import
DistinctCountsAggregateFacetSearchSerializer
from
course_discovery.apps.edx_catalog_extensions.api.serializers
import
DistinctCountsAggregateFacetSearchSerializer
class
DistinctCountsAggregateSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
class
DistinctCountsAggregateSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'extensions:api:v1:search-all-facets'
)
path
=
reverse
(
'extensions:api:v1:search-all-facets'
)
...
...
course_discovery/apps/ietf_language_tags/tests/test_lookups.py
View file @
e251012e
...
@@ -4,13 +4,14 @@ import ddt
...
@@ -4,13 +4,14 @@ import ddt
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
# pylint: disable=no-member
# pylint: disable=no-member
@ddt.ddt
@ddt.ddt
class
AutocompleteTests
(
TestCase
):
class
AutocompleteTests
(
SiteMixin
,
TestCase
):
""" Tests for autocomplete lookups."""
""" Tests for autocomplete lookups."""
def
setUp
(
self
):
def
setUp
(
self
):
super
(
AutocompleteTests
,
self
)
.
setUp
()
super
(
AutocompleteTests
,
self
)
.
setUp
()
...
...
course_discovery/apps/publisher/api/serializers.py
View file @
e251012e
...
@@ -40,7 +40,8 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
...
@@ -40,7 +40,8 @@ class CourseUserRoleSerializer(serializers.ModelSerializer):
former_user
=
instance
.
user
former_user
=
instance
.
user
instance
=
super
(
CourseUserRoleSerializer
,
self
)
.
update
(
instance
,
validated_data
)
instance
=
super
(
CourseUserRoleSerializer
,
self
)
.
update
(
instance
,
validated_data
)
if
not
instance
.
role
==
PublisherUserRole
.
CourseTeam
:
if
not
instance
.
role
==
PublisherUserRole
.
CourseTeam
:
send_change_role_assignment_email
(
instance
,
former_user
)
request
=
self
.
context
[
'request'
]
send_change_role_assignment_email
(
instance
,
former_user
,
request
.
site
)
return
instance
return
instance
...
@@ -104,6 +105,7 @@ class CourseRunSerializer(serializers.ModelSerializer):
...
@@ -104,6 +105,7 @@ class CourseRunSerializer(serializers.ModelSerializer):
instance
=
super
(
CourseRunSerializer
,
self
)
.
update
(
instance
,
validated_data
)
instance
=
super
(
CourseRunSerializer
,
self
)
.
update
(
instance
,
validated_data
)
preview_url
=
validated_data
.
get
(
'preview_url'
)
preview_url
=
validated_data
.
get
(
'preview_url'
)
lms_course_id
=
validated_data
.
get
(
'lms_course_id'
)
lms_course_id
=
validated_data
.
get
(
'lms_course_id'
)
request
=
self
.
context
[
'request'
]
if
preview_url
:
if
preview_url
:
# Change ownership to CourseTeam.
# Change ownership to CourseTeam.
...
@@ -111,10 +113,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
...
@@ -111,10 +113,10 @@ class CourseRunSerializer(serializers.ModelSerializer):
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
if
preview_url
:
if
preview_url
:
send_email_preview_page_is_available
(
instance
)
send_email_preview_page_is_available
(
instance
,
site
=
request
.
site
)
elif
lms_course_id
:
elif
lms_course_id
:
send_email_for_studio_instance_created
(
instance
)
send_email_for_studio_instance_created
(
instance
,
site
=
request
.
site
)
return
instance
return
instance
...
@@ -167,7 +169,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
...
@@ -167,7 +169,7 @@ class CourseStateSerializer(serializers.ModelSerializer):
state
=
validated_data
.
get
(
'name'
)
state
=
validated_data
.
get
(
'name'
)
request
=
self
.
context
.
get
(
'request'
)
request
=
self
.
context
.
get
(
'request'
)
try
:
try
:
instance
.
change_state
(
state
=
state
,
user
=
request
.
user
)
instance
.
change_state
(
state
=
state
,
user
=
request
.
user
,
site
=
request
.
site
)
except
TransitionNotAllowed
:
except
TransitionNotAllowed
:
# pylint: disable=no-member
# pylint: disable=no-member
raise
serializers
.
ValidationError
(
raise
serializers
.
ValidationError
(
...
@@ -204,7 +206,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
...
@@ -204,7 +206,7 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
if
state
:
if
state
:
try
:
try
:
instance
.
change_state
(
state
=
state
,
user
=
request
.
user
)
instance
.
change_state
(
state
=
state
,
user
=
request
.
user
,
site
=
request
.
site
)
except
TransitionNotAllowed
:
except
TransitionNotAllowed
:
# pylint: disable=no-member
# pylint: disable=no-member
raise
serializers
.
ValidationError
(
raise
serializers
.
ValidationError
(
...
@@ -223,6 +225,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
...
@@ -223,6 +225,6 @@ class CourseRunStateSerializer(serializers.ModelSerializer):
instance
.
save
()
instance
.
save
()
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
send_email_preview_accepted
(
instance
.
course_run
)
send_email_preview_accepted
(
instance
.
course_run
,
request
.
site
)
return
instance
return
instance
course_discovery/apps/publisher/api/tests/test_serializers.py
View file @
e251012e
...
@@ -5,6 +5,7 @@ from django.test import RequestFactory, TestCase
...
@@ -5,6 +5,7 @@ from django.test import RequestFactory, TestCase
from
opaque_keys.edx.keys
import
CourseKey
from
opaque_keys.edx.keys
import
CourseKey
from
rest_framework.exceptions
import
ValidationError
from
rest_framework.exceptions
import
ValidationError
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
...
@@ -20,7 +21,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
...
@@ -20,7 +21,7 @@ from course_discovery.apps.publisher.tests.factories import (CourseFactory, Cour
OrganizationExtensionFactory
,
SeatFactory
)
OrganizationExtensionFactory
,
SeatFactory
)
class
CourseUserRoleSerializerTests
(
TestCase
):
class
CourseUserRoleSerializerTests
(
SiteMixin
,
TestCase
):
serializer_class
=
CourseUserRoleSerializer
serializer_class
=
CourseUserRoleSerializer
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -28,6 +29,7 @@ class CourseUserRoleSerializerTests(TestCase):
...
@@ -28,6 +29,7 @@ class CourseUserRoleSerializerTests(TestCase):
self
.
request
=
RequestFactory
()
self
.
request
=
RequestFactory
()
self
.
course_user_role
=
CourseUserRoleFactory
(
role
=
PublisherUserRole
.
MarketingReviewer
)
self
.
course_user_role
=
CourseUserRoleFactory
(
role
=
PublisherUserRole
.
MarketingReviewer
)
self
.
request
.
user
=
self
.
course_user_role
.
user
self
.
request
.
user
=
self
.
course_user_role
.
user
self
.
request
.
site
=
self
.
site
def
get_expected_data
(
self
):
def
get_expected_data
(
self
):
""" Helper method which will return expected serialize data. """
""" Helper method which will return expected serialize data. """
...
@@ -138,7 +140,7 @@ class CourseRunSerializerTests(TestCase):
...
@@ -138,7 +140,7 @@ class CourseRunSerializerTests(TestCase):
"""
"""
self
.
course_run
.
preview_url
=
''
self
.
course_run
.
preview_url
=
''
self
.
course_run
.
save
()
self
.
course_run
.
save
()
serializer
=
self
.
serializer_class
(
self
.
course_run
)
serializer
=
self
.
serializer_class
(
self
.
course_run
,
context
=
{
'request'
:
self
.
request
}
)
serializer
.
update
(
self
.
course_run
,
{
'preview_url'
:
'https://example.com/abc/course'
})
serializer
.
update
(
self
.
course_run
,
{
'preview_url'
:
'https://example.com/abc/course'
})
self
.
assertEqual
(
self
.
course_state
.
owner_role
,
PublisherUserRole
.
CourseTeam
)
self
.
assertEqual
(
self
.
course_state
.
owner_role
,
PublisherUserRole
.
CourseTeam
)
...
@@ -246,13 +248,12 @@ class CourseRevisionSerializerTests(TestCase):
...
@@ -246,13 +248,12 @@ class CourseRevisionSerializerTests(TestCase):
self
.
assertDictEqual
(
serializer
.
data
,
expected
)
self
.
assertDictEqual
(
serializer
.
data
,
expected
)
class
CourseStateSerializerTests
(
TestCase
):
class
CourseStateSerializerTests
(
SiteMixin
,
TestCase
):
serializer_class
=
CourseStateSerializer
serializer_class
=
CourseStateSerializer
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseStateSerializerTests
,
self
)
.
setUp
()
super
(
CourseStateSerializerTests
,
self
)
.
setUp
()
self
.
course_state
=
CourseStateFactory
(
name
=
CourseStateChoices
.
Draft
)
self
.
course_state
=
CourseStateFactory
(
name
=
CourseStateChoices
.
Draft
)
self
.
request
=
RequestFactory
()
self
.
user
=
UserFactory
()
self
.
user
=
UserFactory
()
self
.
request
.
user
=
self
.
user
self
.
request
.
user
=
self
.
user
...
@@ -289,14 +290,13 @@ class CourseStateSerializerTests(TestCase):
...
@@ -289,14 +290,13 @@ class CourseStateSerializerTests(TestCase):
serializer
.
update
(
self
.
course_state
,
data
)
serializer
.
update
(
self
.
course_state
,
data
)
class
CourseRunStateSerializerTests
(
TestCase
):
class
CourseRunStateSerializerTests
(
SiteMixin
,
TestCase
):
serializer_class
=
CourseRunStateSerializer
serializer_class
=
CourseRunStateSerializer
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseRunStateSerializerTests
,
self
)
.
setUp
()
super
(
CourseRunStateSerializerTests
,
self
)
.
setUp
()
self
.
run_state
=
CourseRunStateFactory
(
name
=
CourseRunStateChoices
.
Draft
)
self
.
run_state
=
CourseRunStateFactory
(
name
=
CourseRunStateChoices
.
Draft
)
self
.
course_run
=
self
.
run_state
.
course_run
self
.
course_run
=
self
.
run_state
.
course_run
self
.
request
=
RequestFactory
()
self
.
user
=
UserFactory
()
self
.
user
=
UserFactory
()
self
.
request
.
user
=
self
.
user
self
.
request
.
user
=
self
.
user
CourseStateFactory
(
name
=
CourseStateChoices
.
Approved
,
course
=
self
.
course_run
.
course
)
CourseStateFactory
(
name
=
CourseStateChoices
.
Approved
,
course
=
self
.
course_run
.
course
)
...
...
course_discovery/apps/publisher/api/tests/test_views.py
View file @
e251012e
...
@@ -4,7 +4,6 @@ from urllib.parse import quote
...
@@ -4,7 +4,6 @@ from urllib.parse import quote
import
ddt
import
ddt
from
django.contrib.auth.models
import
Group
from
django.contrib.auth.models
import
Group
from
django.contrib.sites.models
import
Site
from
django.core
import
mail
from
django.core
import
mail
from
django.db
import
IntegrityError
from
django.db
import
IntegrityError
from
django.test
import
TestCase
from
django.test
import
TestCase
...
@@ -14,6 +13,7 @@ from mock import mock, patch
...
@@ -14,6 +13,7 @@ from mock import mock, patch
from
opaque_keys.edx.keys
import
CourseKey
from
opaque_keys.edx.keys
import
CourseKey
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
...
@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
...
@@ -28,7 +28,7 @@ from course_discovery.apps.publisher.tests import JSON_CONTENT_TYPE, factories
@ddt.ddt
@ddt.ddt
class
CourseRoleAssignmentViewTests
(
TestCase
):
class
CourseRoleAssignmentViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseRoleAssignmentViewTests
,
self
)
.
setUp
()
super
(
CourseRoleAssignmentViewTests
,
self
)
.
setUp
()
...
@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(TestCase):
...
@@ -139,7 +139,7 @@ class CourseRoleAssignmentViewTests(TestCase):
self
.
assertEqual
(
len
(
mail
.
outbox
),
1
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
1
)
class
OrganizationGroupUserViewTests
(
TestCase
):
class
OrganizationGroupUserViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
OrganizationGroupUserViewTests
,
self
)
.
setUp
()
super
(
OrganizationGroupUserViewTests
,
self
)
.
setUp
()
...
@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(TestCase):
...
@@ -189,7 +189,7 @@ class OrganizationGroupUserViewTests(TestCase):
)
)
class
UpdateCourseRunViewTests
(
TestCase
):
class
UpdateCourseRunViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
UpdateCourseRunViewTests
,
self
)
.
setUp
()
super
(
UpdateCourseRunViewTests
,
self
)
.
setUp
()
...
@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(TestCase):
...
@@ -313,7 +313,7 @@ class UpdateCourseRunViewTests(TestCase):
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
self
.
assertIn
(
expected_body
,
body
)
self
.
assertIn
(
expected_body
,
body
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
object_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
def
test_update_preview_url
(
self
):
def
test_update_preview_url
(
self
):
...
@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(TestCase):
...
@@ -377,7 +377,7 @@ class UpdateCourseRunViewTests(TestCase):
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
class
CourseRevisionDetailViewTests
(
TestCase
):
class
CourseRevisionDetailViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseRevisionDetailViewTests
,
self
)
.
setUp
()
super
(
CourseRevisionDetailViewTests
,
self
)
.
setUp
()
...
@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(TestCase):
...
@@ -431,7 +431,7 @@ class CourseRevisionDetailViewTests(TestCase):
return
self
.
client
.
get
(
path
=
course_revision_path
)
return
self
.
client
.
get
(
path
=
course_revision_path
)
class
ChangeCourseStateViewTests
(
TestCase
):
class
ChangeCourseStateViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
ChangeCourseStateViewTests
,
self
)
.
setUp
()
super
(
ChangeCourseStateViewTests
,
self
)
.
setUp
()
...
@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(TestCase):
...
@@ -530,7 +530,7 @@ class ChangeCourseStateViewTests(TestCase):
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
object_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
object_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
object_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
def
test_change_course_state_with_error
(
self
):
def
test_change_course_state_with_error
(
self
):
...
@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(TestCase):
...
@@ -587,7 +587,7 @@ class ChangeCourseStateViewTests(TestCase):
self
.
_assert_email_sent
(
course_team_user
,
subject
)
self
.
_assert_email_sent
(
course_team_user
,
subject
)
class
ChangeCourseRunStateViewTests
(
TestCase
):
class
ChangeCourseRunStateViewTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
ChangeCourseRunStateViewTests
,
self
)
.
setUp
()
super
(
ChangeCourseRunStateViewTests
,
self
)
.
setUp
()
...
@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(TestCase):
...
@@ -796,7 +796,7 @@ class ChangeCourseRunStateViewTests(TestCase):
self
.
assertIn
(
'has been published'
,
mail
.
outbox
[
0
]
.
body
.
strip
())
self
.
assertIn
(
'has been published'
,
mail
.
outbox
[
0
]
.
body
.
strip
())
class
RevertCourseByRevisionTests
(
TestCase
):
class
RevertCourseByRevisionTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
RevertCourseByRevisionTests
,
self
)
.
setUp
()
super
(
RevertCourseByRevisionTests
,
self
)
.
setUp
()
...
@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(TestCase):
...
@@ -860,7 +860,7 @@ class RevertCourseByRevisionTests(TestCase):
return
self
.
client
.
put
(
path
=
course_revision_path
)
return
self
.
client
.
put
(
path
=
course_revision_path
)
class
CoursesAutoCompleteTests
(
TestCase
):
class
CoursesAutoCompleteTests
(
SiteMixin
,
TestCase
):
""" Tests for course autocomplete."""
""" Tests for course autocomplete."""
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(TestCase):
...
@@ -927,7 +927,7 @@ class CoursesAutoCompleteTests(TestCase):
self
.
assertEqual
(
len
(
data
[
'results'
]),
expected_length
)
self
.
assertEqual
(
len
(
data
[
'results'
]),
expected_length
)
class
AcceptAllByRevisionTests
(
TestCase
):
class
AcceptAllByRevisionTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
AcceptAllByRevisionTests
,
self
)
.
setUp
()
super
(
AcceptAllByRevisionTests
,
self
)
.
setUp
()
...
...
course_discovery/apps/publisher/emails.py
View file @
e251012e
import
logging
import
logging
from
django.conf
import
settings
from
django.conf
import
settings
from
django.contrib.sites.models
import
Site
from
django.core.mail.message
import
EmailMultiAlternatives
from
django.core.mail.message
import
EmailMultiAlternatives
from
django.template.loader
import
get_template
from
django.template.loader
import
get_template
from
django.urls
import
reverse
from
django.urls
import
reverse
...
@@ -16,11 +15,12 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled
...
@@ -16,11 +15,12 @@ from course_discovery.apps.publisher.utils import is_email_notification_enabled
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
send_email_for_studio_instance_created
(
course_run
):
def
send_email_for_studio_instance_created
(
course_run
,
site
):
""" Send an email to course team on studio instance creation.
""" Send an email to course team on studio instance creation.
Arguments:
Arguments:
course_run (CourseRun): CourseRun object
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
"""
try
:
try
:
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
...
@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run):
...
@@ -39,7 +39,7 @@ def send_email_for_studio_instance_created(course_run):
context
=
{
context
=
{
'course_run'
:
course_run
,
'course_run'
:
course_run
,
'course_run_page_url'
:
'https://{host}{path}'
.
format
(
'course_run_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
object_path
),
),
'course_name'
:
course_run
.
course
.
title
,
'course_name'
:
course_run
.
course
.
title
,
'from_address'
:
from_address
,
'from_address'
:
from_address
,
...
@@ -65,12 +65,13 @@ def send_email_for_studio_instance_created(course_run):
...
@@ -65,12 +65,13 @@ def send_email_for_studio_instance_created(course_run):
raise
Exception
(
error_message
)
raise
Exception
(
error_message
)
def
send_email_for_course_creation
(
course
,
course_run
):
def
send_email_for_course_creation
(
course
,
course_run
,
site
):
""" Send the emails for a course creation.
""" Send the emails for a course creation.
Arguments:
Arguments:
course (Course): Course object
course (Course): Course object
course_run (CourseRun): CourseRun object
course_run (CourseRun): CourseRun object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_created.txt'
txt_template
=
'publisher/email/course_created.txt'
html_template
=
'publisher/email/course_created.html'
html_template
=
'publisher/email/course_created.html'
...
@@ -91,7 +92,7 @@ def send_email_for_course_creation(course, course_run):
...
@@ -91,7 +92,7 @@ def send_email_for_course_creation(course, course_run):
'course_team_name'
:
course_team
.
get_full_name
(),
'course_team_name'
:
course_team
.
get_full_name
(),
'project_coordinator_name'
:
project_coordinator
.
get_full_name
(),
'project_coordinator_name'
:
project_coordinator
.
get_full_name
(),
'dashboard_url'
:
'https://{host}{path}'
.
format
(
'dashboard_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
reverse
(
'publisher:publisher_dashboard'
)
host
=
site
.
domain
.
strip
(
'/'
),
path
=
reverse
(
'publisher:publisher_dashboard'
)
),
),
'from_address'
:
from_address
,
'from_address'
:
from_address
,
'contact_us_email'
:
project_coordinator
.
email
'contact_us_email'
:
project_coordinator
.
email
...
@@ -113,12 +114,13 @@ def send_email_for_course_creation(course, course_run):
...
@@ -113,12 +114,13 @@ def send_email_for_course_creation(course, course_run):
)
)
def
send_email_for_send_for_review
(
course
,
user
):
def
send_email_for_send_for_review
(
course
,
user
,
site
):
""" Send email when course is submitted for review.
""" Send email when course is submitted for review.
Arguments:
Arguments:
course (Object): Course object
course (Object): Course object
user (Object): User object
user (Object): User object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course/send_for_review.txt'
txt_template
=
'publisher/email/course/send_for_review.txt'
html_template
=
'publisher/email/course/send_for_review.html'
html_template
=
'publisher/email/course/send_for_review.html'
...
@@ -135,21 +137,22 @@ def send_email_for_send_for_review(course, user):
...
@@ -135,21 +137,22 @@ def send_email_for_send_for_review(course, user):
'course_name'
:
course
.
title
,
'course_name'
:
course
.
title
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'marketing team'
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'marketing team'
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
)
}
}
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
)
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
,
site
)
except
Exception
:
# pylint: disable=broad-except
except
Exception
:
# pylint: disable=broad-except
logger
.
exception
(
'Failed to send email notifications send for review of course
%
s'
,
course
.
id
)
logger
.
exception
(
'Failed to send email notifications send for review of course
%
s'
,
course
.
id
)
def
send_email_for_mark_as_reviewed
(
course
,
user
):
def
send_email_for_mark_as_reviewed
(
course
,
user
,
site
):
""" Send email when course is marked as reviewed.
""" Send email when course is marked as reviewed.
Arguments:
Arguments:
course (Object): Course object
course (Object): Course object
user (Object): User object
user (Object): User object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course/mark_as_reviewed.txt'
txt_template
=
'publisher/email/course/mark_as_reviewed.txt'
html_template
=
'publisher/email/course/mark_as_reviewed.html'
html_template
=
'publisher/email/course/mark_as_reviewed.html'
...
@@ -166,16 +169,16 @@ def send_email_for_mark_as_reviewed(course, user):
...
@@ -166,16 +169,16 @@ def send_email_for_mark_as_reviewed(course, user):
'course_name'
:
course
.
title
,
'course_name'
:
course
.
title
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'marketing team'
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'marketing team'
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
)
}
}
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
)
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
,
site
)
except
Exception
:
# pylint: disable=broad-except
except
Exception
:
# pylint: disable=broad-except
logger
.
exception
(
'Failed to send email notifications mark as reviewed of course
%
s'
,
course
.
id
)
logger
.
exception
(
'Failed to send email notifications mark as reviewed of course
%
s'
,
course
.
id
)
def
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
):
def
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
,
site
):
""" Send email for course workflow state change.
""" Send email for course workflow state change.
Arguments:
Arguments:
...
@@ -186,6 +189,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
...
@@ -186,6 +189,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
html_template: (String): Email html template path
html_template: (String): Email html template path
context: (Dict): Email template context
context: (Dict): Email template context
recipient_user: (Object): User object
recipient_user: (Object): User object
site (Site): Current site
"""
"""
if
is_email_notification_enabled
(
recipient_user
):
if
is_email_notification_enabled
(
recipient_user
):
...
@@ -202,7 +206,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
...
@@ -202,7 +206,7 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'course_page_url'
:
'https://{host}{path}'
.
format
(
'course_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
course_page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
course_page_path
)
)
}
}
)
)
...
@@ -219,12 +223,13 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
...
@@ -219,12 +223,13 @@ def send_course_workflow_email(course, user, subject, txt_template, html_templat
email_msg
.
send
()
email_msg
.
send
()
def
send_email_for_send_for_review_course_run
(
course_run
,
user
):
def
send_email_for_send_for_review_course_run
(
course_run
,
user
,
site
):
""" Send email when course-run is submitted for review.
""" Send email when course-run is submitted for review.
Arguments:
Arguments:
course-run (Object): CourseRun object
course-run (Object): CourseRun object
user (Object): User object
user (Object): User object
site (Site): Current site
"""
"""
course
=
course_run
.
course
course
=
course_run
.
course
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
...
@@ -246,22 +251,23 @@ def send_email_for_send_for_review_course_run(course_run, user):
...
@@ -246,22 +251,23 @@ def send_email_for_send_for_review_course_run(course_run, user):
'run_number'
:
course_key
.
run
,
'run_number'
:
course_key
.
run
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'project coordinators'
,
'sender_team'
:
'course team'
if
user_role
.
role
==
PublisherUserRole
.
CourseTeam
else
'project coordinators'
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
),
),
'studio_url'
:
course_run
.
studio_url
'studio_url'
:
course_run
.
studio_url
}
}
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
)
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
,
site
)
except
Exception
:
# pylint: disable=broad-except
except
Exception
:
# pylint: disable=broad-except
logger
.
exception
(
'Failed to send email notifications send for review of course-run
%
s'
,
course_run
.
id
)
logger
.
exception
(
'Failed to send email notifications send for review of course-run
%
s'
,
course_run
.
id
)
def
send_email_for_mark_as_reviewed_course_run
(
course_run
,
user
):
def
send_email_for_mark_as_reviewed_course_run
(
course_run
,
user
,
site
):
""" Send email when course-run is marked as reviewed.
""" Send email when course-run is marked as reviewed.
Arguments:
Arguments:
course_run (Object): CourseRun object
course_run (Object): CourseRun object
user (Object): User object
user (Object): User object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_run/mark_as_reviewed.txt'
txt_template
=
'publisher/email/course_run/mark_as_reviewed.txt'
html_template
=
'publisher/email/course_run/mark_as_reviewed.html'
html_template
=
'publisher/email/course_run/mark_as_reviewed.html'
...
@@ -284,21 +290,24 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user):
...
@@ -284,21 +290,24 @@ def send_email_for_mark_as_reviewed_course_run(course_run, user):
'run_number'
:
course_key
.
run
,
'run_number'
:
course_key
.
run
,
'sender_team'
:
'course team'
,
'sender_team'
:
'course team'
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
)
}
}
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
)
send_course_workflow_email
(
course
,
user
,
subject
,
txt_template
,
html_template
,
context
,
recipient_user
,
site
)
except
Exception
:
# pylint: disable=broad-except
except
Exception
:
# pylint: disable=broad-except
logger
.
exception
(
'Failed to send email notifications for mark as reviewed of course-run
%
s'
,
course_run
.
id
)
logger
.
exception
(
'Failed to send email notifications for mark as reviewed of course-run
%
s'
,
course_run
.
id
)
def
send_email_to_publisher
(
course_run
,
user
):
def
send_email_to_publisher
(
course_run
,
user
,
site
):
""" Send email to publisher when course-run is marked as reviewed.
""" Send email to publisher when course-run is marked as reviewed.
Arguments:
Arguments:
course_run (Object): CourseRun object
course_run (Object): CourseRun object
user (Object): User object
user (Object): User object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_run/mark_as_reviewed.txt'
txt_template
=
'publisher/email/course_run/mark_as_reviewed.txt'
html_template
=
'publisher/email/course_run/mark_as_reviewed.html'
html_template
=
'publisher/email/course_run/mark_as_reviewed.html'
...
@@ -330,7 +339,7 @@ def send_email_to_publisher(course_run, user):
...
@@ -330,7 +339,7 @@ def send_email_to_publisher(course_run, user):
'sender_team'
:
sender_team
,
'sender_team'
:
sender_team
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
)
}
}
...
@@ -349,11 +358,12 @@ def send_email_to_publisher(course_run, user):
...
@@ -349,11 +358,12 @@ def send_email_to_publisher(course_run, user):
logger
.
exception
(
'Failed to send email notifications for mark as reviewed of course-run
%
s'
,
course_run
.
id
)
logger
.
exception
(
'Failed to send email notifications for mark as reviewed of course-run
%
s'
,
course_run
.
id
)
def
send_email_preview_accepted
(
course_run
):
def
send_email_preview_accepted
(
course_run
,
site
):
""" Send email for preview approved to publisher and project coordinator.
""" Send email for preview approved to publisher and project coordinator.
Arguments:
Arguments:
course_run (Object): CourseRun object
course_run (Object): CourseRun object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_run/preview_accepted.txt'
txt_template
=
'publisher/email/course_run/preview_accepted.txt'
html_template
=
'publisher/email/course_run/preview_accepted.html'
html_template
=
'publisher/email/course_run/preview_accepted.html'
...
@@ -382,10 +392,10 @@ def send_email_preview_accepted(course_run):
...
@@ -382,10 +392,10 @@ def send_email_preview_accepted(course_run):
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
),
),
'course_page_url'
:
'https://{host}{path}'
.
format
(
'course_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
course_page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
course_page_path
)
)
}
}
template
=
get_template
(
txt_template
)
template
=
get_template
(
txt_template
)
...
@@ -406,11 +416,12 @@ def send_email_preview_accepted(course_run):
...
@@ -406,11 +416,12 @@ def send_email_preview_accepted(course_run):
raise
Exception
(
message
)
raise
Exception
(
message
)
def
send_email_preview_page_is_available
(
course_run
):
def
send_email_preview_page_is_available
(
course_run
,
site
):
""" Send email for course preview available to course team.
""" Send email for course preview available to course team.
Arguments:
Arguments:
course_run (Object): CourseRun object
course_run (Object): CourseRun object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_run/preview_available.txt'
txt_template
=
'publisher/email/course_run/preview_available.txt'
html_template
=
'publisher/email/course_run/preview_available.html'
html_template
=
'publisher/email/course_run/preview_available.html'
...
@@ -436,10 +447,10 @@ def send_email_preview_page_is_available(course_run):
...
@@ -436,10 +447,10 @@ def send_email_preview_page_is_available(course_run):
'preview_link'
:
course_run
.
preview_url
,
'preview_link'
:
course_run
.
preview_url
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
),
),
'course_page_url'
:
'https://{host}{path}'
.
format
(
'course_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
course_page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
course_page_path
),
),
'platform_name'
:
settings
.
PLATFORM_NAME
'platform_name'
:
settings
.
PLATFORM_NAME
}
}
...
@@ -462,11 +473,12 @@ def send_email_preview_page_is_available(course_run):
...
@@ -462,11 +473,12 @@ def send_email_preview_page_is_available(course_run):
raise
Exception
(
error_message
)
raise
Exception
(
error_message
)
def
send_course_run_published_email
(
course_run
):
def
send_course_run_published_email
(
course_run
,
site
):
""" Send email when course run is published by publisher.
""" Send email when course run is published by publisher.
Arguments:
Arguments:
course_run (Object): CourseRun object
course_run (Object): CourseRun object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course_run/published.txt'
txt_template
=
'publisher/email/course_run/published.txt'
html_template
=
'publisher/email/course_run/published.html'
html_template
=
'publisher/email/course_run/published.html'
...
@@ -492,10 +504,10 @@ def send_course_run_published_email(course_run):
...
@@ -492,10 +504,10 @@ def send_course_run_published_email(course_run):
'recipient_name'
:
course_team_user
.
get_full_name
()
or
course_team_user
.
username
,
'recipient_name'
:
course_team_user
.
get_full_name
()
or
course_team_user
.
username
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'page_url'
:
'https://{host}{path}'
.
format
(
'page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
),
),
'course_page_url'
:
'https://{host}{path}'
.
format
(
'course_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
course_page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
course_page_path
),
),
'platform_name'
:
settings
.
PLATFORM_NAME
,
'platform_name'
:
settings
.
PLATFORM_NAME
,
}
}
...
@@ -518,12 +530,13 @@ def send_course_run_published_email(course_run):
...
@@ -518,12 +530,13 @@ def send_course_run_published_email(course_run):
raise
Exception
(
error_message
)
raise
Exception
(
error_message
)
def
send_change_role_assignment_email
(
course_role
,
former_user
):
def
send_change_role_assignment_email
(
course_role
,
former_user
,
site
):
""" Send email for role assignment changed.
""" Send email for role assignment changed.
Arguments:
Arguments:
course_role (Object): CourseUserRole object
course_role (Object): CourseUserRole object
former_user (Object): User object
former_user (Object): User object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/role_assignment_changed.txt'
txt_template
=
'publisher/email/role_assignment_changed.txt'
html_template
=
'publisher/email/role_assignment_changed.html'
html_template
=
'publisher/email/role_assignment_changed.html'
...
@@ -549,7 +562,7 @@ def send_change_role_assignment_email(course_role, former_user):
...
@@ -549,7 +562,7 @@ def send_change_role_assignment_email(course_role, former_user):
'current_user_name'
:
course_role
.
user
.
get_full_name
()
or
course_role
.
user
.
username
,
'current_user_name'
:
course_role
.
user
.
get_full_name
()
or
course_role
.
user
.
username
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'contact_us_email'
:
project_coordinator
.
email
if
project_coordinator
else
''
,
'course_url'
:
'https://{host}{path}'
.
format
(
'course_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
page_path
),
),
'platform_name'
:
settings
.
PLATFORM_NAME
,
'platform_name'
:
settings
.
PLATFORM_NAME
,
}
}
...
@@ -572,11 +585,12 @@ def send_change_role_assignment_email(course_role, former_user):
...
@@ -572,11 +585,12 @@ def send_change_role_assignment_email(course_role, former_user):
raise
Exception
(
error_message
)
raise
Exception
(
error_message
)
def
send_email_for_seo_review
(
course
):
def
send_email_for_seo_review
(
course
,
site
):
""" Send email when course is submitted for seo review.
""" Send email when course is submitted for seo review.
Arguments:
Arguments:
course (Object): Course object
course (Object): Course object
site (Site): Current site
"""
"""
txt_template
=
'publisher/email/course/seo_review.txt'
txt_template
=
'publisher/email/course/seo_review.txt'
html_template
=
'publisher/email/course/seo_review.html'
html_template
=
'publisher/email/course/seo_review.html'
...
@@ -597,7 +611,7 @@ def send_email_for_seo_review(course):
...
@@ -597,7 +611,7 @@ def send_email_for_seo_review(course):
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'org_name'
:
course
.
organizations
.
all
()
.
first
()
.
name
,
'contact_us_email'
:
project_coordinator
.
email
,
'contact_us_email'
:
project_coordinator
.
email
,
'course_page_url'
:
'https://{host}{path}'
.
format
(
'course_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
course_page_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
course_page_path
)
)
}
}
...
@@ -615,11 +629,12 @@ def send_email_for_seo_review(course):
...
@@ -615,11 +629,12 @@ def send_email_for_seo_review(course):
logger
.
exception
(
'Failed to send email notifications for legal review requested of course
%
s'
,
course
.
id
)
logger
.
exception
(
'Failed to send email notifications for legal review requested of course
%
s'
,
course
.
id
)
def
send_email_for_published_course_run_editing
(
course_run
):
def
send_email_for_published_course_run_editing
(
course_run
,
site
):
""" Send email when published course-run is edited.
""" Send email when published course-run is edited.
Arguments:
Arguments:
course-run (Object): Course Run object
course-run (Object): Course Run object
site (Site): Current site
"""
"""
try
:
try
:
course
=
course_run
.
course
course
=
course_run
.
course
...
@@ -644,7 +659,7 @@ def send_email_for_published_course_run_editing(course_run):
...
@@ -644,7 +659,7 @@ def send_email_for_published_course_run_editing(course_run):
'recipient_name'
:
publisher_user
.
get_full_name
()
or
publisher_user
.
username
,
'recipient_name'
:
publisher_user
.
get_full_name
()
or
publisher_user
.
username
,
'contact_us_email'
:
course
.
project_coordinator
.
email
,
'contact_us_email'
:
course
.
project_coordinator
.
email
,
'course_run_page_url'
:
'https://{host}{path}'
.
format
(
'course_run_page_url'
:
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
host
=
site
.
domain
.
strip
(
'/'
),
path
=
object_path
),
),
'course_run_number'
:
course_key
.
run
,
'course_run_number'
:
course_key
.
run
,
}
}
...
...
course_discovery/apps/publisher/models.py
View file @
e251012e
...
@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
...
@@ -617,7 +617,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
# TODO: send email etc.
# TODO: send email etc.
pass
pass
def
change_state
(
self
,
state
,
user
):
def
change_state
(
self
,
state
,
user
,
site
=
None
):
"""
"""
Change course workflow state and ownership also send emails if required.
Change course workflow state and ownership also send emails if required.
"""
"""
...
@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
...
@@ -632,12 +632,12 @@ class CourseState(TimeStampedModel, ChangedByMixin):
elif
user_role
.
role
==
PublisherUserRole
.
CourseTeam
:
elif
user_role
.
role
==
PublisherUserRole
.
CourseTeam
:
self
.
change_owner_role
(
PublisherUserRole
.
MarketingReviewer
)
self
.
change_owner_role
(
PublisherUserRole
.
MarketingReviewer
)
if
is_notifications_enabled
:
if
is_notifications_enabled
:
emails
.
send_email_for_seo_review
(
self
.
course
)
emails
.
send_email_for_seo_review
(
self
.
course
,
site
)
self
.
review
()
self
.
review
()
if
is_notifications_enabled
:
if
is_notifications_enabled
:
emails
.
send_email_for_send_for_review
(
self
.
course
,
user
)
emails
.
send_email_for_send_for_review
(
self
.
course
,
user
,
site
)
elif
state
==
CourseStateChoices
.
Approved
:
elif
state
==
CourseStateChoices
.
Approved
:
user_role
=
self
.
course
.
course_user_roles
.
get
(
user
=
user
)
user_role
=
self
.
course
.
course_user_roles
.
get
(
user
=
user
)
...
@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
...
@@ -646,7 +646,7 @@ class CourseState(TimeStampedModel, ChangedByMixin):
self
.
approved
()
self
.
approved
()
if
is_notifications_enabled
:
if
is_notifications_enabled
:
emails
.
send_email_for_mark_as_reviewed
(
self
.
course
,
user
)
emails
.
send_email_for_mark_as_reviewed
(
self
.
course
,
user
,
site
)
self
.
save
()
self
.
save
()
...
@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
...
@@ -744,10 +744,10 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
pass
pass
@transition
(
field
=
name
,
source
=
CourseRunStateChoices
.
Approved
,
target
=
CourseRunStateChoices
.
Published
)
@transition
(
field
=
name
,
source
=
CourseRunStateChoices
.
Approved
,
target
=
CourseRunStateChoices
.
Published
)
def
published
(
self
):
def
published
(
self
,
site
):
emails
.
send_course_run_published_email
(
self
.
course_run
)
emails
.
send_course_run_published_email
(
self
.
course_run
,
site
)
def
change_state
(
self
,
state
,
user
):
def
change_state
(
self
,
state
,
user
,
site
=
None
):
"""
"""
Change course run workflow state and ownership also send emails if required.
Change course run workflow state and ownership also send emails if required.
"""
"""
...
@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
...
@@ -763,7 +763,7 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self
.
review
()
self
.
review
()
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run
,
user
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run
,
user
,
site
)
elif
state
==
CourseRunStateChoices
.
Approved
:
elif
state
==
CourseRunStateChoices
.
Approved
:
user_role
=
self
.
course_run
.
course
.
course_user_roles
.
get
(
user
=
user
)
user_role
=
self
.
course_run
.
course
.
course_user_roles
.
get
(
user
=
user
)
...
@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
...
@@ -772,11 +772,11 @@ class CourseRunState(TimeStampedModel, ChangedByMixin):
self
.
approved
()
self
.
approved
()
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
if
waffle
.
switch_is_active
(
'enable_publisher_email_notifications'
):
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run
,
user
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run
,
user
,
site
)
emails
.
send_email_to_publisher
(
self
.
course_run
,
user
)
emails
.
send_email_to_publisher
(
self
.
course_run
,
user
,
site
)
elif
state
==
CourseRunStateChoices
.
Published
:
elif
state
==
CourseRunStateChoices
.
Published
:
self
.
published
()
self
.
published
(
site
)
self
.
save
()
self
.
save
()
...
...
course_discovery/apps/publisher/tests/test_admin.py
View file @
e251012e
...
@@ -4,6 +4,7 @@ from django.test import TestCase
...
@@ -4,6 +4,7 @@ from django.test import TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
guardian.shortcuts
import
get_group_perms
from
guardian.shortcuts
import
get_group_perms
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
from
course_discovery.apps.publisher.choices
import
PublisherUserRole
from
course_discovery.apps.publisher.choices
import
PublisherUserRole
...
@@ -18,7 +19,7 @@ USER_PASSWORD = 'password'
...
@@ -18,7 +19,7 @@ USER_PASSWORD = 'password'
# pylint: disable=no-member
# pylint: disable=no-member
class
AdminTests
(
TestCase
):
class
AdminTests
(
SiteMixin
,
TestCase
):
""" Tests Admin page."""
""" Tests Admin page."""
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -81,7 +82,7 @@ class AdminTests(TestCase):
...
@@ -81,7 +82,7 @@ class AdminTests(TestCase):
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
class
OrganizationExtensionAdminTests
(
TestCase
):
class
OrganizationExtensionAdminTests
(
SiteMixin
,
TestCase
):
""" Tests for OrganizationExtensionAdmin."""
""" Tests for OrganizationExtensionAdmin."""
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -134,7 +135,7 @@ class OrganizationExtensionAdminTests(TestCase):
...
@@ -134,7 +135,7 @@ class OrganizationExtensionAdminTests(TestCase):
@ddt.ddt
@ddt.ddt
class
OrganizationUserRoleAdminTests
(
TestCase
):
class
OrganizationUserRoleAdminTests
(
SiteMixin
,
TestCase
):
""" Tests for OrganizationUserRoleAdmin."""
""" Tests for OrganizationUserRoleAdmin."""
def
setUp
(
self
):
def
setUp
(
self
):
...
...
course_discovery/apps/publisher/tests/test_emails.py
View file @
e251012e
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
import
mock
import
mock
from
django.contrib.auth.models
import
Group
from
django.contrib.auth.models
import
Group
from
django.contrib.sites.models
import
Site
from
django.core
import
mail
from
django.core
import
mail
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
opaque_keys.edx.keys
import
CourseKey
from
opaque_keys.edx.keys
import
CourseKey
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
...
@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories
...
@@ -21,7 +21,7 @@ from course_discovery.apps.publisher.tests import factories
from
course_discovery.apps.publisher.tests.factories
import
UserAttributeFactory
from
course_discovery.apps.publisher.tests.factories
import
UserAttributeFactory
class
StudioInstanceCreatedEmailTests
(
TestCase
):
class
StudioInstanceCreatedEmailTests
(
SiteMixin
,
TestCase
):
"""
"""
Tests for the studio instance created email functionality.
Tests for the studio instance created email functionality.
"""
"""
...
@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(TestCase):
...
@@ -50,14 +50,14 @@ class StudioInstanceCreatedEmailTests(TestCase):
""" Verify that emails failure raise exception."""
""" Verify that emails failure raise exception."""
with
self
.
assertRaises
(
Exception
)
as
ex
:
with
self
.
assertRaises
(
Exception
)
as
ex
:
emails
.
send_email_for_studio_instance_created
(
self
.
course_run
)
emails
.
send_email_for_studio_instance_created
(
self
.
course_run
,
self
.
site
)
error_message
=
'Failed to send email notifications for course_run [{}]'
.
format
(
self
.
course_run
.
id
)
error_message
=
'Failed to send email notifications for course_run [{}]'
.
format
(
self
.
course_run
.
id
)
self
.
assertEqual
(
ex
.
message
,
error_message
)
self
.
assertEqual
(
ex
.
message
,
error_message
)
def
test_email_sent_successfully
(
self
):
def
test_email_sent_successfully
(
self
):
""" Verify that emails sent successfully for studio instance created."""
""" Verify that emails sent successfully for studio instance created."""
emails
.
send_email_for_studio_instance_created
(
self
.
course_run
)
emails
.
send_email_for_studio_instance_created
(
self
.
course_run
,
self
.
site
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
self
.
assert_email_sent
(
self
.
assert_email_sent
(
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
}),
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
}),
...
@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
...
@@ -76,7 +76,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
self
.
assertIn
(
expected_body
,
body
)
self
.
assertIn
(
expected_body
,
body
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
object_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'Enter course run content in Studio.'
,
body
)
self
.
assertIn
(
'Enter course run content in Studio.'
,
body
)
self
.
assertIn
(
'Thanks'
,
body
)
self
.
assertIn
(
'Thanks'
,
body
)
...
@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
...
@@ -89,7 +89,7 @@ class StudioInstanceCreatedEmailTests(TestCase):
)
)
class
CourseCreatedEmailTests
(
TestCase
):
class
CourseCreatedEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the new course created email functionality. """
""" Tests for the new course created email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(TestCase):
...
@@ -116,7 +116,7 @@ class CourseCreatedEmailTests(TestCase):
""" Verify that emails failure logs error message."""
""" Verify that emails failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
)
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(TestCase):
...
@@ -130,7 +130,7 @@ class CourseCreatedEmailTests(TestCase):
def
test_email_sent_successfully
(
self
):
def
test_email_sent_successfully
(
self
):
""" Verify that studio instance request email sent successfully."""
""" Verify that studio instance request email sent successfully."""
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
)
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
,
self
.
site
)
subject
=
'Studio URL requested: {title}'
.
format
(
title
=
self
.
course_run
.
course
.
title
)
subject
=
'Studio URL requested: {title}'
.
format
(
title
=
self
.
course_run
.
course
.
title
)
self
.
assert_email_sent
(
subject
)
self
.
assert_email_sent
(
subject
)
...
@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(TestCase):
...
@@ -151,12 +151,12 @@ class CourseCreatedEmailTests(TestCase):
user_attribute
=
UserAttributes
.
objects
.
get
(
user
=
self
.
user
)
user_attribute
=
UserAttributes
.
objects
.
get
(
user
=
self
.
user
)
user_attribute
.
enable_email_notification
=
False
user_attribute
.
enable_email_notification
=
False
user_attribute
.
save
()
user_attribute
.
save
()
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
)
emails
.
send_email_for_course_creation
(
self
.
course_run
.
course
,
self
.
course_run
,
self
.
site
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
class
SendForReviewEmailTests
(
TestCase
):
class
SendForReviewEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the send for review email functionality. """
""" Tests for the send for review email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -168,7 +168,7 @@ class SendForReviewEmailTests(TestCase):
...
@@ -168,7 +168,7 @@ class SendForReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
""" Verify that email failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_send_for_review
(
self
.
course_state
.
course
,
self
.
user
)
emails
.
send_email_for_send_for_review
(
self
.
course_state
.
course
,
self
.
user
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -180,7 +180,7 @@ class SendForReviewEmailTests(TestCase):
...
@@ -180,7 +180,7 @@ class SendForReviewEmailTests(TestCase):
)
)
class
CourseMarkAsReviewedEmailTests
(
TestCase
):
class
CourseMarkAsReviewedEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the mark as reviewed email functionality. """
""" Tests for the mark as reviewed email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
...
@@ -192,7 +192,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
""" Verify that email failure logs error message."""
""" Verify that email failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_mark_as_reviewed
(
self
.
course_state
.
course
,
self
.
user
)
emails
.
send_email_for_mark_as_reviewed
(
self
.
course_state
.
course
,
self
.
user
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
...
@@ -204,7 +204,7 @@ class CourseMarkAsReviewedEmailTests(TestCase):
)
)
class
CourseRunSendForReviewEmailTests
(
TestCase
):
class
CourseRunSendForReviewEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the CourseRun send for review email functionality. """
""" Tests for the CourseRun send for review email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
...
@@ -238,7 +238,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
factories
.
CourseUserRoleFactory
(
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
)
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user
,
self
.
site
)
subject
=
'Review requested: {title} {run_number}'
.
format
(
title
=
self
.
course
,
run_number
=
self
.
course_key
.
run
)
subject
=
'Review requested: {title} {run_number}'
.
format
(
title
=
self
.
course
,
run_number
=
self
.
course_key
.
run
)
self
.
assert_email_sent
(
subject
,
self
.
user_2
)
self
.
assert_email_sent
(
subject
,
self
.
user_2
)
...
@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
...
@@ -247,7 +247,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
factories
.
CourseUserRoleFactory
(
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
)
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user_2
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user_2
,
self
.
site
)
subject
=
'Review requested: {title} {run_number}'
.
format
(
title
=
self
.
course
,
run_number
=
self
.
course_key
.
run
)
subject
=
'Review requested: {title} {run_number}'
.
format
(
title
=
self
.
course
,
run_number
=
self
.
course_key
.
run
)
self
.
assert_email_sent
(
subject
,
self
.
user
)
self
.
assert_email_sent
(
subject
,
self
.
user
)
...
@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
...
@@ -255,7 +255,7 @@ class CourseRunSendForReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
""" Verify that email failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run
,
self
.
user
)
emails
.
send_email_for_send_for_review_course_run
(
self
.
course_run
,
self
.
user
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(TestCase):
...
@@ -273,12 +273,12 @@ class CourseRunSendForReviewEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'View this course run in Publisher to review the changes or suggest edits.'
,
body
)
self
.
assertIn
(
'View this course run in Publisher to review the changes or suggest edits.'
,
body
)
class
CourseRunMarkAsReviewedEmailTests
(
TestCase
):
class
CourseRunMarkAsReviewedEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the CourseRun mark as reviewed email functionality. """
""" Tests for the CourseRun mark as reviewed email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
...
@@ -311,7 +311,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories
.
CourseUserRoleFactory
(
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
)
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user
,
self
.
site
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
def
test_email_sent_by_course_team
(
self
):
def
test_email_sent_by_course_team
(
self
):
...
@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
...
@@ -319,14 +319,14 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories
.
CourseUserRoleFactory
(
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
)
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user_2
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run_state
.
course_run
,
self
.
user_2
,
self
.
site
)
self
.
assert_email_sent
(
self
.
user
)
self
.
assert_email_sent
(
self
.
user
)
def
test_email_mark_as_reviewed_with_error
(
self
):
def
test_email_mark_as_reviewed_with_error
(
self
):
""" Verify that email failure log error message."""
""" Verify that email failure log error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run
,
self
.
user
)
emails
.
send_email_for_mark_as_reviewed_course_run
(
self
.
course_run
,
self
.
user
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
...
@@ -342,7 +342,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
factories
.
CourseUserRoleFactory
(
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
,
user
=
self
.
user
)
)
emails
.
send_email_to_publisher
(
self
.
course_run_state
.
course_run
,
self
.
user
)
emails
.
send_email_to_publisher
(
self
.
course_run_state
.
course_run
,
self
.
user
,
self
.
site
)
self
.
assert_email_sent
(
self
.
user_3
)
self
.
assert_email_sent
(
self
.
user_3
)
def
test_email_to_publisher_with_error
(
self
):
def
test_email_to_publisher_with_error
(
self
):
...
@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
...
@@ -350,7 +350,7 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_to_publisher
(
self
.
course_run
,
self
.
user_3
)
emails
.
send_email_to_publisher
(
self
.
course_run
,
self
.
user_3
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
...
@@ -375,12 +375,12 @@ class CourseRunMarkAsReviewedEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
course_run
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'The review for this course run is complete.'
,
body
)
self
.
assertIn
(
'The review for this course run is complete.'
,
body
)
class
CourseRunPreviewEmailTests
(
TestCase
):
class
CourseRunPreviewEmailTests
(
SiteMixin
,
TestCase
):
"""
"""
Tests for the course preview email functionality.
Tests for the course preview email functionality.
"""
"""
...
@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -414,7 +414,7 @@ class CourseRunPreviewEmailTests(TestCase):
lms_course_id
=
'course-v1:edX+DemoX+Demo_Course'
lms_course_id
=
'course-v1:edX+DemoX+Demo_Course'
self
.
run_state
.
course_run
.
lms_course_id
=
lms_course_id
self
.
run_state
.
course_run
.
lms_course_id
=
lms_course_id
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
)
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
,
self
.
site
)
course_key
=
CourseKey
.
from_string
(
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
lms_course_id
)
subject
=
'Publication requested: {course_name} {run_number}'
.
format
(
subject
=
'Publication requested: {course_name} {run_number}'
.
format
(
...
@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -426,7 +426,7 @@ class CourseRunPreviewEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
run_state
.
course_run
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
self
.
run_state
.
course_run
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'You can now publish this About page.'
,
body
)
self
.
assertIn
(
'You can now publish this About page.'
,
body
)
...
@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -440,7 +440,7 @@ class CourseRunPreviewEmailTests(TestCase):
with
self
.
assertRaises
(
Exception
)
as
ex
:
with
self
.
assertRaises
(
Exception
)
as
ex
:
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
)
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -457,7 +457,7 @@ class CourseRunPreviewEmailTests(TestCase):
course_run
.
lms_course_id
=
'course-v1:testX+testX1.0+2017T1'
course_run
.
lms_course_id
=
'course-v1:testX+testX1.0+2017T1'
course_run
.
save
()
course_run
.
save
()
emails
.
send_email_preview_page_is_available
(
course_run
)
emails
.
send_email_preview_page_is_available
(
course_run
,
self
.
site
)
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
course_run
.
lms_course_id
)
subject
=
'Review requested: Preview for {course_name} {run_number}'
.
format
(
subject
=
'Review requested: Preview for {course_name} {run_number}'
.
format
(
...
@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -469,7 +469,7 @@ class CourseRunPreviewEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
course_run
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_run_detail'
,
kwargs
=
{
'pk'
:
course_run
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'A preview is now available for the'
,
body
)
self
.
assertIn
(
'A preview is now available for the'
,
body
)
...
@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -477,7 +477,7 @@ class CourseRunPreviewEmailTests(TestCase):
""" Verify that exception raised on email failure."""
""" Verify that exception raised on email failure."""
with
self
.
assertRaises
(
Exception
)
as
ex
:
with
self
.
assertRaises
(
Exception
)
as
ex
:
emails
.
send_email_preview_page_is_available
(
self
.
run_state
.
course_run
)
emails
.
send_email_preview_page_is_available
(
self
.
run_state
.
course_run
,
self
.
site
)
error_message
=
'Failed to send email notifications for preview available of course-run {}'
.
format
(
error_message
=
'Failed to send email notifications for preview available of course-run {}'
.
format
(
self
.
run_state
.
course_run
.
id
self
.
run_state
.
course_run
.
id
)
)
...
@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(TestCase):
...
@@ -486,19 +486,19 @@ class CourseRunPreviewEmailTests(TestCase):
def
test_preview_available_email_with_notification_disabled
(
self
):
def
test_preview_available_email_with_notification_disabled
(
self
):
""" Verify that email not sent if notification disabled by user."""
""" Verify that email not sent if notification disabled by user."""
factories
.
UserAttributeFactory
(
user
=
self
.
course
.
course_team_admin
,
enable_email_notification
=
False
)
factories
.
UserAttributeFactory
(
user
=
self
.
course
.
course_team_admin
,
enable_email_notification
=
False
)
emails
.
send_email_preview_page_is_available
(
self
.
run_state
.
course_run
)
emails
.
send_email_preview_page_is_available
(
self
.
run_state
.
course_run
,
self
.
site
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
def
test_preview_accepted_email_with_notification_disabled
(
self
):
def
test_preview_accepted_email_with_notification_disabled
(
self
):
""" Verify that preview accepted email not sent if notification disabled by user."""
""" Verify that preview accepted email not sent if notification disabled by user."""
factories
.
UserAttributeFactory
(
user
=
self
.
course
.
publisher
,
enable_email_notification
=
False
)
factories
.
UserAttributeFactory
(
user
=
self
.
course
.
publisher
,
enable_email_notification
=
False
)
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
)
emails
.
send_email_preview_accepted
(
self
.
run_state
.
course_run
,
self
.
site
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
0
)
class
CourseRunPublishedEmailTests
(
TestCase
):
class
CourseRunPublishedEmailTests
(
SiteMixin
,
TestCase
):
"""
"""
Tests for course run published email functionality.
Tests for course run published email functionality.
"""
"""
...
@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(TestCase):
...
@@ -527,7 +527,7 @@ class CourseRunPublishedEmailTests(TestCase):
"""
"""
self
.
course_run
.
lms_course_id
=
'course-v1:testX+test45+2017T2'
self
.
course_run
.
lms_course_id
=
'course-v1:testX+test45+2017T2'
self
.
course_run
.
save
()
self
.
course_run
.
save
()
emails
.
send_course_run_published_email
(
self
.
course_run
)
emails
.
send_course_run_published_email
(
self
.
course_run
,
self
.
site
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
subject
=
'Publication complete: About page for {course_name} {run_number}'
.
format
(
subject
=
'Publication complete: About page for {course_name} {run_number}'
.
format
(
...
@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(TestCase):
...
@@ -550,11 +550,11 @@ class CourseRunPublishedEmailTests(TestCase):
)
)
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
self
.
assertRaises
(
Exception
)
as
ex
:
with
self
.
assertRaises
(
Exception
)
as
ex
:
emails
.
send_course_run_published_email
(
self
.
course_run
)
emails
.
send_course_run_published_email
(
self
.
course_run
,
self
.
site
)
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
class
CourseChangeRoleAssignmentEmailTests
(
TestCase
):
class
CourseChangeRoleAssignmentEmailTests
(
SiteMixin
,
TestCase
):
"""
"""
Tests email functionality for course role assignment changed.
Tests email functionality for course role assignment changed.
"""
"""
...
@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
...
@@ -575,7 +575,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
"""
"""
Verify that course role assignment chnage email functionality works fine.
Verify that course role assignment chnage email functionality works fine.
"""
"""
emails
.
send_change_role_assignment_email
(
self
.
marketing_role
,
self
.
user
)
emails
.
send_change_role_assignment_email
(
self
.
marketing_role
,
self
.
user
,
self
.
site
)
expected_subject
=
'{role_name} changed for {course_title}'
.
format
(
expected_subject
=
'{role_name} changed for {course_title}'
.
format
(
role_name
=
self
.
marketing_role
.
get_role_display
()
.
lower
(),
role_name
=
self
.
marketing_role
.
get_role_display
()
.
lower
(),
course_title
=
self
.
course
.
title
course_title
=
self
.
course
.
title
...
@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
...
@@ -589,7 +589,7 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'has changed.'
,
body
)
self
.
assertIn
(
'has changed.'
,
body
)
...
@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
...
@@ -603,11 +603,11 @@ class CourseChangeRoleAssignmentEmailTests(TestCase):
)
)
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
mock
.
patch
(
'django.core.mail.message.EmailMessage.send'
,
side_effect
=
TypeError
):
with
self
.
assertRaises
(
Exception
)
as
ex
:
with
self
.
assertRaises
(
Exception
)
as
ex
:
emails
.
send_change_role_assignment_email
(
self
.
marketing_role
,
self
.
user
)
emails
.
send_change_role_assignment_email
(
self
.
marketing_role
,
self
.
user
,
self
.
site
)
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
self
.
assertEqual
(
str
(
ex
.
exception
),
message
)
class
SEOReviewEmailTests
(
TestCase
):
class
SEOReviewEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the seo review email functionality. """
""" Tests for the seo review email functionality. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -626,7 +626,7 @@ class SEOReviewEmailTests(TestCase):
...
@@ -626,7 +626,7 @@ class SEOReviewEmailTests(TestCase):
""" Verify that email failure logs error message."""
""" Verify that email failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_seo_review
(
self
.
course
)
emails
.
send_email_for_seo_review
(
self
.
course
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
@@ -642,7 +642,7 @@ class SEOReviewEmailTests(TestCase):
...
@@ -642,7 +642,7 @@ class SEOReviewEmailTests(TestCase):
Verify that seo review email functionality works fine.
Verify that seo review email functionality works fine.
"""
"""
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
)
factories
.
CourseUserRoleFactory
(
course
=
self
.
course
,
role
=
PublisherUserRole
.
ProjectCoordinator
)
emails
.
send_email_for_seo_review
(
self
.
course
)
emails
.
send_email_for_seo_review
(
self
.
course
,
self
.
site
)
expected_subject
=
'Legal review requested: {title}'
.
format
(
title
=
self
.
course
.
title
)
expected_subject
=
'Legal review requested: {title}'
.
format
(
title
=
self
.
course
.
title
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
1
)
self
.
assertEqual
(
len
(
mail
.
outbox
),
1
)
...
@@ -652,7 +652,7 @@ class SEOReviewEmailTests(TestCase):
...
@@ -652,7 +652,7 @@ class SEOReviewEmailTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
page_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
page_path
=
reverse
(
'publisher:publisher_course_detail'
,
kwargs
=
{
'pk'
:
self
.
course
.
id
})
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
page_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
page_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
'determine OFAC status'
,
body
)
self
.
assertIn
(
'determine OFAC status'
,
body
)
...
@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
...
@@ -671,7 +671,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
)
)
self
.
course_run
.
lms_course_id
=
'course-v1:testX+test45+2017T2'
self
.
course_run
.
lms_course_id
=
'course-v1:testX+test45+2017T2'
self
.
course_run
.
save
()
self
.
course_run
.
save
()
emails
.
send_email_for_published_course_run_editing
(
self
.
course_run
)
emails
.
send_email_for_published_course_run_editing
(
self
.
course_run
,
self
.
site
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
course_key
=
CourseKey
.
from_string
(
self
.
course_run
.
lms_course_id
)
...
@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
...
@@ -692,7 +692,7 @@ class CourseRunPublishedEditEmailTests(CourseRunPublishedEmailTests):
""" Verify that email failure logs error message."""
""" Verify that email failure logs error message."""
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
with
LogCapture
(
emails
.
logger
.
name
)
as
l
:
emails
.
send_email_for_published_course_run_editing
(
self
.
course_run
)
emails
.
send_email_for_published_course_run_editing
(
self
.
course_run
,
self
.
site
)
l
.
check
(
l
.
check
(
(
(
emails
.
logger
.
name
,
emails
.
logger
.
name
,
...
...
course_discovery/apps/publisher/tests/test_model.py
View file @
e251012e
...
@@ -6,7 +6,7 @@ from django.urls import reverse
...
@@ -6,7 +6,7 @@ from django.urls import reverse
from
django_fsm
import
TransitionNotAllowed
from
django_fsm
import
TransitionNotAllowed
from
guardian.shortcuts
import
assign_perm
from
guardian.shortcuts
import
assign_perm
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
PartnerFactory
,
SiteFactory
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
,
PersonFactory
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
,
PersonFactory
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
...
@@ -525,6 +525,8 @@ class CourseStateTests(TestCase):
...
@@ -525,6 +525,8 @@ class CourseStateTests(TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CourseStateTests
,
self
)
.
setUp
()
super
(
CourseStateTests
,
self
)
.
setUp
()
self
.
site
=
SiteFactory
()
self
.
partner
=
PartnerFactory
(
site
=
self
.
site
)
self
.
course
=
self
.
course_state
.
course
self
.
course
=
self
.
course_state
.
course
self
.
course
.
image
=
make_image_file
(
'test_banner.jpg'
)
self
.
course
.
image
=
make_image_file
(
'test_banner.jpg'
)
self
.
course
.
save
()
self
.
course
.
save
()
...
@@ -548,7 +550,7 @@ class CourseStateTests(TestCase):
...
@@ -548,7 +550,7 @@ class CourseStateTests(TestCase):
"""
"""
self
.
assertNotEqual
(
self
.
course_state
.
name
,
state
)
self
.
assertNotEqual
(
self
.
course_state
.
name
,
state
)
self
.
course_state
.
change_state
(
state
=
state
,
user
=
self
.
user
)
self
.
course_state
.
change_state
(
state
=
state
,
user
=
self
.
user
,
site
=
self
.
site
)
self
.
assertEqual
(
self
.
course_state
.
name
,
state
)
self
.
assertEqual
(
self
.
course_state
.
name
,
state
)
...
@@ -561,7 +563,7 @@ class CourseStateTests(TestCase):
...
@@ -561,7 +563,7 @@ class CourseStateTests(TestCase):
self
.
assertEqual
(
self
.
course_state
.
name
,
CourseStateChoices
.
Draft
)
self
.
assertEqual
(
self
.
course_state
.
name
,
CourseStateChoices
.
Draft
)
with
self
.
assertRaises
(
TransitionNotAllowed
):
with
self
.
assertRaises
(
TransitionNotAllowed
):
self
.
course_state
.
change_state
(
state
=
CourseStateChoices
.
Review
,
user
=
self
.
user
)
self
.
course_state
.
change_state
(
state
=
CourseStateChoices
.
Review
,
user
=
self
.
user
,
site
=
self
.
site
)
def
test_can_send_for_review
(
self
):
def
test_can_send_for_review
(
self
):
"""
"""
...
@@ -673,6 +675,9 @@ class CourseRunStateTests(TestCase):
...
@@ -673,6 +675,9 @@ class CourseRunStateTests(TestCase):
language_tag
=
LanguageTag
(
code
=
'te-st'
,
name
=
'Test Language'
)
language_tag
=
LanguageTag
(
code
=
'te-st'
,
name
=
'Test Language'
)
language_tag
.
save
()
language_tag
.
save
()
self
.
site
=
SiteFactory
()
self
.
partner
=
PartnerFactory
(
site
=
self
.
site
)
self
.
course_run
.
transcript_languages
.
add
(
language_tag
)
self
.
course_run
.
transcript_languages
.
add
(
language_tag
)
self
.
course_run
.
language
=
language_tag
self
.
course_run
.
language
=
language_tag
self
.
course_run
.
is_micromasters
=
True
self
.
course_run
.
is_micromasters
=
True
...
@@ -703,7 +708,7 @@ class CourseRunStateTests(TestCase):
...
@@ -703,7 +708,7 @@ class CourseRunStateTests(TestCase):
Verify that we can change course-run state according to workflow.
Verify that we can change course-run state according to workflow.
"""
"""
self
.
assertNotEqual
(
self
.
course_run_state
.
name
,
state
)
self
.
assertNotEqual
(
self
.
course_run_state
.
name
,
state
)
self
.
course_run_state
.
change_state
(
state
=
state
,
user
=
self
.
user
)
self
.
course_run_state
.
change_state
(
state
=
state
,
user
=
self
.
user
,
site
=
self
.
site
)
self
.
assertEqual
(
self
.
course_run_state
.
name
,
state
)
self
.
assertEqual
(
self
.
course_run_state
.
name
,
state
)
def
test_with_invalid_parent_course_state
(
self
):
def
test_with_invalid_parent_course_state
(
self
):
...
...
course_discovery/apps/publisher/tests/test_views.py
View file @
e251012e
...
@@ -19,11 +19,13 @@ from opaque_keys.edx.keys import CourseKey
...
@@ -19,11 +19,13 @@ from opaque_keys.edx.keys import CourseKey
from
pytz
import
timezone
from
pytz
import
timezone
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests.factories
import
CourseFactory
,
OrganizationFactory
,
PersonFactory
from
course_discovery.apps.course_metadata.tests.factories
import
(
CourseFactory
,
OrganizationFactory
,
PersonFactory
,
SubjectFactory
)
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
from
course_discovery.apps.ietf_language_tags.models
import
LanguageTag
from
course_discovery.apps.publisher.choices
import
(
CourseRunStateChoices
,
CourseStateChoices
,
InternalUserRole
,
from
course_discovery.apps.publisher.choices
import
(
CourseRunStateChoices
,
CourseStateChoices
,
InternalUserRole
,
PublisherUserRole
)
PublisherUserRole
)
...
@@ -42,7 +44,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
...
@@ -42,7 +44,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
@ddt.ddt
class
CreateCourseViewTests
(
TestCase
):
class
CreateCourseViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the publisher `CreateCourseView`. """
""" Tests for the publisher `CreateCourseView`. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -61,7 +63,6 @@ class CreateCourseViewTests(TestCase):
...
@@ -61,7 +63,6 @@ class CreateCourseViewTests(TestCase):
self
.
course
=
factories
.
CourseFactory
()
self
.
course
=
factories
.
CourseFactory
()
self
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
self
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
self
.
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
# creating default organizations roles
# creating default organizations roles
...
@@ -269,7 +270,7 @@ class CreateCourseViewTests(TestCase):
...
@@ -269,7 +270,7 @@ class CreateCourseViewTests(TestCase):
)
)
class
CreateCourseRunViewTests
(
TestCase
):
class
CreateCourseRunViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the publisher `UpdateCourseRunView`. """
""" Tests for the publisher `UpdateCourseRunView`. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -299,7 +300,6 @@ class CreateCourseRunViewTests(TestCase):
...
@@ -299,7 +300,6 @@ class CreateCourseRunViewTests(TestCase):
current_datetime
=
datetime
.
now
(
timezone
(
'US/Central'
))
current_datetime
=
datetime
.
now
(
timezone
(
'US/Central'
))
self
.
course_run_dict
[
'start'
]
=
(
current_datetime
+
timedelta
(
days
=
1
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
self
.
course_run_dict
[
'start'
]
=
(
current_datetime
+
timedelta
(
days
=
1
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
self
.
course_run_dict
[
'end'
]
=
(
current_datetime
+
timedelta
(
days
=
3
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
self
.
course_run_dict
[
'end'
]
=
(
current_datetime
+
timedelta
(
days
=
3
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
self
.
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
def
_pop_valuse_from_dict
(
self
,
data_dict
,
key_list
):
def
_pop_valuse_from_dict
(
self
,
data_dict
,
key_list
):
...
@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(TestCase):
...
@@ -562,7 +562,7 @@ class CreateCourseRunViewTests(TestCase):
@ddt.ddt
@ddt.ddt
class
CourseRunDetailTests
(
TestCase
):
class
CourseRunDetailTests
(
SiteMixin
,
TestCase
):
""" Tests for the course-run detail view. """
""" Tests for the course-run detail view. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -763,9 +763,8 @@ class CourseRunDetailTests(TestCase):
...
@@ -763,9 +763,8 @@ class CourseRunDetailTests(TestCase):
"""
"""
self
.
client
.
logout
()
self
.
client
.
logout
()
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
comment
=
CommentFactory
(
content_object
=
self
.
course_run
,
user
=
self
.
user
,
site
=
site
)
comment
=
CommentFactory
(
content_object
=
self
.
course_run
,
user
=
self
.
user
,
site
=
s
elf
.
s
ite
)
response
=
self
.
client
.
get
(
self
.
page_url
)
response
=
self
.
client
.
get
(
self
.
page_url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
_assert_credits_seats
(
response
,
self
.
wrapped_course_run
.
credit_seat
)
self
.
_assert_credits_seats
(
response
,
self
.
wrapped_course_run
.
credit_seat
)
...
@@ -779,7 +778,7 @@ class CourseRunDetailTests(TestCase):
...
@@ -779,7 +778,7 @@ class CourseRunDetailTests(TestCase):
# test decline comment appearing on detail page also.
# test decline comment appearing on detail page also.
decline_comment
=
CommentFactory
(
decline_comment
=
CommentFactory
(
content_object
=
self
.
course_run
,
content_object
=
self
.
course_run
,
user
=
self
.
user
,
site
=
site
,
comment_type
=
CommentTypeChoices
.
Decline_Preview
user
=
self
.
user
,
site
=
s
elf
.
s
ite
,
comment_type
=
CommentTypeChoices
.
Decline_Preview
)
)
response
=
self
.
client
.
get
(
self
.
page_url
)
response
=
self
.
client
.
get
(
self
.
page_url
)
self
.
assertContains
(
response
,
decline_comment
.
comment
)
self
.
assertContains
(
response
,
decline_comment
.
comment
)
...
@@ -1233,12 +1232,12 @@ class CourseRunDetailTests(TestCase):
...
@@ -1233,12 +1232,12 @@ class CourseRunDetailTests(TestCase):
# pylint: disable=attribute-defined-outside-init
# pylint: disable=attribute-defined-outside-init
@ddt.ddt
@ddt.ddt
class
DashboardTests
(
TestCase
):
class
DashboardTests
(
SiteMixin
,
TestCase
):
""" Tests for the `Dashboard`. """
""" Tests for the `Dashboard`. """
def
setUp
(
self
):
def
setUp
(
self
):
super
(
DashboardTests
,
self
)
.
setUp
()
super
(
DashboardTests
,
self
)
.
setUp
()
Site
.
objects
.
exclude
(
id
=
self
.
site
.
id
)
.
delete
()
self
.
group_internal
=
Group
.
objects
.
get
(
name
=
INTERNAL_USER_GROUP_NAME
)
self
.
group_internal
=
Group
.
objects
.
get
(
name
=
INTERNAL_USER_GROUP_NAME
)
self
.
group_project_coordinator
=
Group
.
objects
.
get
(
name
=
PROJECT_COORDINATOR_GROUP_NAME
)
self
.
group_project_coordinator
=
Group
.
objects
.
get
(
name
=
PROJECT_COORDINATOR_GROUP_NAME
)
self
.
group_reviewer
=
Group
.
objects
.
get
(
name
=
REVIEWER_GROUP_NAME
)
self
.
group_reviewer
=
Group
.
objects
.
get
(
name
=
REVIEWER_GROUP_NAME
)
...
@@ -1277,7 +1276,18 @@ class DashboardTests(TestCase):
...
@@ -1277,7 +1276,18 @@ class DashboardTests(TestCase):
def
_create_course_assign_role
(
self
,
state
,
user
,
role
):
def
_create_course_assign_role
(
self
,
state
,
user
,
role
):
""" Create course-run-state, course-user-role and return course-run. """
""" Create course-run-state, course-user-role and return course-run. """
course_run_state
=
factories
.
CourseRunStateFactory
(
name
=
state
,
owner_role
=
role
)
course
=
factories
.
CourseFactory
(
primary_subject
=
SubjectFactory
(
partner
=
self
.
partner
),
secondary_subject
=
SubjectFactory
(
partner
=
self
.
partner
),
tertiary_subject
=
SubjectFactory
(
partner
=
self
.
partner
)
)
course_run
=
factories
.
CourseRunFactory
(
course
=
course
)
course_run_state
=
factories
.
CourseRunStateFactory
(
name
=
state
,
owner_role
=
role
,
course_run
=
course_run
)
factories
.
CourseUserRoleFactory
(
course
=
course_run_state
.
course_run
.
course
,
role
=
role
,
user
=
user
)
factories
.
CourseUserRoleFactory
(
course
=
course_run_state
.
course_run
.
course
,
role
=
role
,
user
=
user
)
return
course_run_state
.
course_run
return
course_run_state
.
course_run
...
@@ -1301,7 +1311,7 @@ class DashboardTests(TestCase):
...
@@ -1301,7 +1311,7 @@ class DashboardTests(TestCase):
self
.
client
.
logout
()
self
.
client
.
logout
()
self
.
client
.
login
(
username
=
UserFactory
(),
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
UserFactory
(),
password
=
USER_PASSWORD
)
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
1
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
2
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1309,7 +1319,7 @@ class DashboardTests(TestCase):
...
@@ -1309,7 +1319,7 @@ class DashboardTests(TestCase):
def
test_with_internal_group
(
self
,
tab
):
def
test_with_internal_group
(
self
,
tab
):
""" Verify that internal user can see courses assigned to the groups. """
""" Verify that internal user can see courses assigned to the groups. """
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
4
)
)
self
.
assertContains
(
response
,
'<li role="tab" id="tab-{tab}" class="tab"'
.
format
(
tab
=
tab
))
self
.
assertContains
(
response
,
'<li role="tab" id="tab-{tab}" class="tab"'
.
format
(
tab
=
tab
))
...
@@ -1324,7 +1334,7 @@ class DashboardTests(TestCase):
...
@@ -1324,7 +1334,7 @@ class DashboardTests(TestCase):
self
.
course_run_1
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
self
.
course_run_1
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
1
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
2
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1349,14 +1359,14 @@ class DashboardTests(TestCase):
...
@@ -1349,14 +1359,14 @@ class DashboardTests(TestCase):
)
)
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
0
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
1
studio_count
=
0
,
published_count
=
0
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
2
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
def
test_studio_request_course_runs_as_pc
(
self
):
def
test_studio_request_course_runs_as_pc
(
self
):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
4
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1364,7 +1374,7 @@ class DashboardTests(TestCase):
...
@@ -1364,7 +1374,7 @@ class DashboardTests(TestCase):
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
""" Verify that PC user can see only those courses on which he is assigned as PC role. """
self
.
user1
.
groups
.
remove
(
self
.
group_project_coordinator
)
self
.
user1
.
groups
.
remove
(
self
.
group_project_coordinator
)
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
0
studio_count
=
0
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
1
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1375,7 +1385,7 @@ class DashboardTests(TestCase):
...
@@ -1375,7 +1385,7 @@ class DashboardTests(TestCase):
self
.
course_run_2
.
lms_course_id
=
'test-2'
self
.
course_run_2
.
lms_course_id
=
'test-2'
self
.
course_run_2
.
save
()
self
.
course_run_2
.
save
()
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
1
studio_count
=
0
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
2
)
)
self
.
assertContains
(
response
,
'No courses are currently ready for a Studio URL.'
)
self
.
assertContains
(
response
,
'No courses are currently ready for a Studio URL.'
)
...
@@ -1384,7 +1394,7 @@ class DashboardTests(TestCase):
...
@@ -1384,7 +1394,7 @@ class DashboardTests(TestCase):
self
.
course_run_3
.
course_run_state
.
name
=
CourseRunStateChoices
.
Draft
self
.
course_run_3
.
course_run_state
.
name
=
CourseRunStateChoices
.
Draft
self
.
course_run_3
.
course_run_state
.
save
()
self
.
course_run_3
.
course_run_state
.
save
()
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
3
,
published_count
=
0
,
progress_count
=
3
,
preview_count
=
1
,
queries_executed
=
2
4
studio_count
=
3
,
published_count
=
0
,
progress_count
=
3
,
preview_count
=
1
,
queries_executed
=
2
5
)
)
self
.
assertContains
(
response
,
'No About pages have been published yet'
)
self
.
assertContains
(
response
,
'No About pages have been published yet'
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1392,7 +1402,7 @@ class DashboardTests(TestCase):
...
@@ -1392,7 +1402,7 @@ class DashboardTests(TestCase):
def
test_published_course_runs
(
self
):
def
test_published_course_runs
(
self
):
""" Verify that published tab loads course runs list. """
""" Verify that published tab loads course runs list. """
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
published_count
=
1
,
progress_count
=
2
,
preview_count
=
1
,
queries_executed
=
2
4
)
)
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'published'
))
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'published'
))
self
.
assertContains
(
response
,
'About pages for the following course runs have been published in the'
)
self
.
assertContains
(
response
,
'About pages for the following course runs have been published in the'
)
...
@@ -1410,7 +1420,7 @@ class DashboardTests(TestCase):
...
@@ -1410,7 +1420,7 @@ class DashboardTests(TestCase):
# Verify that user cannot see any published course run
# Verify that user cannot see any published course run
self
.
assert_dashboard_response
(
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
5
studio_count
=
0
,
published_count
=
0
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
6
)
)
# assign user course role
# assign user course role
...
@@ -1420,7 +1430,7 @@ class DashboardTests(TestCase):
...
@@ -1420,7 +1430,7 @@ class DashboardTests(TestCase):
# Verify that user can see 1 published course run
# Verify that user can see 1 published course run
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
0
,
published_count
=
1
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
6
studio_count
=
0
,
published_count
=
1
,
progress_count
=
0
,
preview_count
=
0
,
queries_executed
=
1
7
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1434,14 +1444,14 @@ class DashboardTests(TestCase):
...
@@ -1434,14 +1444,14 @@ class DashboardTests(TestCase):
publisher_admin
.
groups
.
add
(
Group
.
objects
.
get
(
name
=
ADMIN_GROUP_NAME
))
publisher_admin
.
groups
.
add
(
Group
.
objects
.
get
(
name
=
ADMIN_GROUP_NAME
))
self
.
client
.
login
(
username
=
publisher_admin
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
publisher_admin
.
username
,
password
=
USER_PASSWORD
)
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
4
,
published_count
=
1
,
progress_count
=
3
,
preview_count
=
1
,
queries_executed
=
2
0
studio_count
=
4
,
published_count
=
1
,
progress_count
=
3
,
preview_count
=
1
,
queries_executed
=
2
1
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
def
test_with_preview_ready_course_runs
(
self
):
def
test_with_preview_ready_course_runs
(
self
):
""" Verify that preview ready tabs loads the course runs list. """
""" Verify that preview ready tabs loads the course runs list. """
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
4
)
)
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'preview'
))
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'preview'
))
self
.
assertContains
(
response
,
'About page previews for the following course runs are available for course team'
)
self
.
assertContains
(
response
,
'About page previews for the following course runs are available for course team'
)
...
@@ -1453,7 +1463,7 @@ class DashboardTests(TestCase):
...
@@ -1453,7 +1463,7 @@ class DashboardTests(TestCase):
self
.
course_run_2
.
course_run_state
.
name
=
CourseRunStateChoices
.
Draft
self
.
course_run_2
.
course_run_state
.
name
=
CourseRunStateChoices
.
Draft
self
.
course_run_2
.
course_run_state
.
save
()
self
.
course_run_2
.
course_run_state
.
save
()
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
preview_count
=
0
,
progress_count
=
3
,
published_count
=
1
,
queries_executed
=
2
2
studio_count
=
2
,
preview_count
=
0
,
progress_count
=
3
,
published_count
=
1
,
queries_executed
=
2
3
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1462,7 +1472,7 @@ class DashboardTests(TestCase):
...
@@ -1462,7 +1472,7 @@ class DashboardTests(TestCase):
preview url is added or not.
preview url is added or not.
"""
"""
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
4
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1470,14 +1480,14 @@ class DashboardTests(TestCase):
...
@@ -1470,14 +1480,14 @@ class DashboardTests(TestCase):
self
.
course_run_2
.
preview_url
=
None
self
.
course_run_2
.
preview_url
=
None
self
.
course_run_2
.
save
()
self
.
course_run_2
.
save
()
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
4
)
)
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
def
test_with_in_progress_course_runs
(
self
):
def
test_with_in_progress_course_runs
(
self
):
""" Verify that in progress tabs loads the course runs list. """
""" Verify that in progress tabs loads the course runs list. """
response
=
self
.
assert_dashboard_response
(
response
=
self
.
assert_dashboard_response
(
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
3
studio_count
=
2
,
preview_count
=
1
,
progress_count
=
2
,
published_count
=
1
,
queries_executed
=
2
4
)
)
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'in-progress'
))
self
.
assertContains
(
response
,
self
.
table_class
.
format
(
id
=
'in-progress'
))
self
.
_assert_tabs_with_roles
(
response
)
self
.
_assert_tabs_with_roles
(
response
)
...
@@ -1513,7 +1523,7 @@ class DashboardTests(TestCase):
...
@@ -1513,7 +1523,7 @@ class DashboardTests(TestCase):
self
.
client
.
logout
()
self
.
client
.
logout
()
self
.
client
.
login
(
username
=
pc_user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
pc_user
.
username
,
password
=
USER_PASSWORD
)
with
self
.
assertNumQueries
(
1
1
):
with
self
.
assertNumQueries
(
1
2
):
response
=
self
.
client
.
get
(
self
.
page_url
)
response
=
self
.
client
.
get
(
self
.
page_url
)
for
tab
in
[
'progress'
,
'preview'
,
'studio'
,
'published'
]:
for
tab
in
[
'progress'
,
'preview'
,
'studio'
,
'published'
]:
...
@@ -1523,7 +1533,7 @@ class DashboardTests(TestCase):
...
@@ -1523,7 +1533,7 @@ class DashboardTests(TestCase):
"""
"""
Verify that site_name is available in context.
Verify that site_name is available in context.
"""
"""
with
self
.
assertNumQueries
(
2
3
):
with
self
.
assertNumQueries
(
2
4
):
response
=
self
.
client
.
get
(
self
.
page_url
)
response
=
self
.
client
.
get
(
self
.
page_url
)
site
=
Site
.
objects
.
first
()
site
=
Site
.
objects
.
first
()
self
.
assertEqual
(
response
.
context
[
'site_name'
],
site
.
name
)
self
.
assertEqual
(
response
.
context
[
'site_name'
],
site
.
name
)
...
@@ -1542,13 +1552,12 @@ class DashboardTests(TestCase):
...
@@ -1542,13 +1552,12 @@ class DashboardTests(TestCase):
course_run
.
course_run_state
.
owner_role
=
PublisherUserRole
.
CourseTeam
course_run
.
course_run_state
.
owner_role
=
PublisherUserRole
.
CourseTeam
course_run
.
course_run_state
.
save
()
course_run
.
course_run_state
.
save
()
with
self
.
assertNumQueries
(
2
5
):
with
self
.
assertNumQueries
(
2
6
):
response
=
self
.
client
.
get
(
self
.
page_url
)
response
=
self
.
client
.
get
(
self
.
page_url
)
site
=
Site
.
objects
.
first
()
self
.
_assert_filter_counts
(
response
,
'All'
,
3
)
self
.
_assert_filter_counts
(
response
,
'All'
,
3
)
self
.
_assert_filter_counts
(
response
,
'With Course Team'
,
2
)
self
.
_assert_filter_counts
(
response
,
'With Course Team'
,
2
)
self
.
_assert_filter_counts
(
response
,
'With {site_name}'
.
format
(
site_name
=
site
.
name
),
1
)
self
.
_assert_filter_counts
(
response
,
'With {site_name}'
.
format
(
site_name
=
s
elf
.
s
ite
.
name
),
1
)
def
_assert_filter_counts
(
self
,
response
,
expected_label
,
count
):
def
_assert_filter_counts
(
self
,
response
,
expected_label
,
count
):
"""
"""
...
@@ -1559,7 +1568,7 @@ class DashboardTests(TestCase):
...
@@ -1559,7 +1568,7 @@ class DashboardTests(TestCase):
self
.
assertContains
(
response
,
expected_count
,
count
=
1
)
self
.
assertContains
(
response
,
expected_count
,
count
=
1
)
class
ToggleEmailNotificationTests
(
TestCase
):
class
ToggleEmailNotificationTests
(
SiteMixin
,
TestCase
):
""" Tests for `ToggleEmailNotification` view. """
""" Tests for `ToggleEmailNotification` view. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -1592,7 +1601,7 @@ class ToggleEmailNotificationTests(TestCase):
...
@@ -1592,7 +1601,7 @@ class ToggleEmailNotificationTests(TestCase):
self
.
assertEqual
(
is_email_notification_enabled
(
user
),
is_enabled
)
self
.
assertEqual
(
is_email_notification_enabled
(
user
),
is_enabled
)
class
CourseListViewTests
(
TestCase
):
class
CourseListViewTests
(
SiteMixin
,
TestCase
):
""" Tests for `CourseListView` """
""" Tests for `CourseListView` """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -1606,12 +1615,12 @@ class CourseListViewTests(TestCase):
...
@@ -1606,12 +1615,12 @@ class CourseListViewTests(TestCase):
def
test_courses_with_no_courses
(
self
):
def
test_courses_with_no_courses
(
self
):
""" Verify that user cannot see any course on course list page. """
""" Verify that user cannot see any course on course list page. """
self
.
assert_course_list_page
(
course_count
=
0
,
queries_executed
=
8
)
self
.
assert_course_list_page
(
course_count
=
0
,
queries_executed
=
9
)
def
test_courses_with_admin
(
self
):
def
test_courses_with_admin
(
self
):
""" Verify that admin user can see all courses on course list page. """
""" Verify that admin user can see all courses on course list page. """
self
.
user
.
groups
.
add
(
Group
.
objects
.
get
(
name
=
ADMIN_GROUP_NAME
))
self
.
user
.
groups
.
add
(
Group
.
objects
.
get
(
name
=
ADMIN_GROUP_NAME
))
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
3
1
)
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
3
2
)
def
test_courses_with_course_user_role
(
self
):
def
test_courses_with_course_user_role
(
self
):
""" Verify that internal user can see course on course list page. """
""" Verify that internal user can see course on course list page. """
...
@@ -1619,7 +1628,7 @@ class CourseListViewTests(TestCase):
...
@@ -1619,7 +1628,7 @@ class CourseListViewTests(TestCase):
for
course
in
self
.
courses
:
for
course
in
self
.
courses
:
factories
.
CourseUserRoleFactory
(
course
=
course
,
user
=
self
.
user
,
role
=
InternalUserRole
.
Publisher
)
factories
.
CourseUserRoleFactory
(
course
=
course
,
user
=
self
.
user
,
role
=
InternalUserRole
.
Publisher
)
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
3
2
)
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
3
3
)
def
test_courses_with_permission
(
self
):
def
test_courses_with_permission
(
self
):
""" Verify that user can see course with permission on course list page. """
""" Verify that user can see course with permission on course list page. """
...
@@ -1630,7 +1639,7 @@ class CourseListViewTests(TestCase):
...
@@ -1630,7 +1639,7 @@ class CourseListViewTests(TestCase):
course
.
organizations
.
add
(
organization_extension
.
organization
)
course
.
organizations
.
add
(
organization_extension
.
organization
)
assign_perm
(
OrganizationExtension
.
VIEW_COURSE
,
organization_extension
.
group
,
organization_extension
)
assign_perm
(
OrganizationExtension
.
VIEW_COURSE
,
organization_extension
.
group
,
organization_extension
)
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
6
4
)
self
.
assert_course_list_page
(
course_count
=
10
,
queries_executed
=
6
5
)
def
assert_course_list_page
(
self
,
course_count
,
queries_executed
):
def
assert_course_list_page
(
self
,
course_count
,
queries_executed
):
""" Dry method to assert course list page content. """
""" Dry method to assert course list page content. """
...
@@ -1657,7 +1666,7 @@ class CourseListViewTests(TestCase):
...
@@ -1657,7 +1666,7 @@ class CourseListViewTests(TestCase):
toggle_switch
(
'publisher_hide_features_for_pilot'
,
True
)
toggle_switch
(
'publisher_hide_features_for_pilot'
,
True
)
with
self
.
assertNumQueries
(
1
7
):
with
self
.
assertNumQueries
(
1
8
):
response
=
self
.
client
.
get
(
self
.
courses_url
)
response
=
self
.
client
.
get
(
self
.
courses_url
)
self
.
assertNotContains
(
response
,
'Edit'
)
self
.
assertNotContains
(
response
,
'Edit'
)
...
@@ -1676,13 +1685,13 @@ class CourseListViewTests(TestCase):
...
@@ -1676,13 +1685,13 @@ class CourseListViewTests(TestCase):
toggle_switch
(
'publisher_hide_features_for_pilot'
,
False
)
toggle_switch
(
'publisher_hide_features_for_pilot'
,
False
)
with
self
.
assertNumQueries
(
2
1
):
with
self
.
assertNumQueries
(
2
2
):
response
=
self
.
client
.
get
(
self
.
courses_url
)
response
=
self
.
client
.
get
(
self
.
courses_url
)
self
.
assertContains
(
response
,
'Edit'
)
self
.
assertContains
(
response
,
'Edit'
)
class
CourseDetailViewTests
(
TestCase
):
class
CourseDetailViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the course detail view. """
""" Tests for the course detail view. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -2114,7 +2123,7 @@ class CourseDetailViewTests(TestCase):
...
@@ -2114,7 +2123,7 @@ class CourseDetailViewTests(TestCase):
@ddt.ddt
@ddt.ddt
class
CourseEditViewTests
(
TestCase
):
class
CourseEditViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the course edit view. """
""" Tests for the course edit view. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -2532,7 +2541,7 @@ class CourseEditViewTests(TestCase):
...
@@ -2532,7 +2541,7 @@ class CourseEditViewTests(TestCase):
@ddt.ddt
@ddt.ddt
class
CourseRunEditViewTests
(
TestCase
):
class
CourseRunEditViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the course run edit view. """
""" Tests for the course run edit view. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -2550,7 +2559,6 @@ class CourseRunEditViewTests(TestCase):
...
@@ -2550,7 +2559,6 @@ class CourseRunEditViewTests(TestCase):
self
.
seat
=
factories
.
SeatFactory
(
course_run
=
self
.
course_run
,
type
=
Seat
.
VERIFIED
,
price
=
2
)
self
.
seat
=
factories
.
SeatFactory
(
course_run
=
self
.
course_run
,
type
=
Seat
.
VERIFIED
,
price
=
2
)
self
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
self
.
course
.
organizations
.
add
(
self
.
organization_extension
.
organization
)
self
.
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
current_datetime
=
datetime
.
now
(
timezone
(
'US/Central'
))
current_datetime
=
datetime
.
now
(
timezone
(
'US/Central'
))
self
.
start_date_time
=
(
current_datetime
+
timedelta
(
days
=
1
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
self
.
start_date_time
=
(
current_datetime
+
timedelta
(
days
=
1
))
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
)
...
@@ -2833,7 +2841,7 @@ class CourseRunEditViewTests(TestCase):
...
@@ -2833,7 +2841,7 @@ class CourseRunEditViewTests(TestCase):
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
body
=
mail
.
outbox
[
0
]
.
body
.
strip
()
self
.
assertIn
(
expected_body
,
body
)
self
.
assertIn
(
expected_body
,
body
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
Site
.
objects
.
get_current
()
.
domain
.
strip
(
'/'
),
path
=
object_path
)
page_url
=
'https://{host}{path}'
.
format
(
host
=
self
.
site
.
domain
.
strip
(
'/'
),
path
=
object_path
)
self
.
assertIn
(
page_url
,
body
)
self
.
assertIn
(
page_url
,
body
)
def
test_studio_instance_with_course_team
(
self
):
def
test_studio_instance_with_course_team
(
self
):
...
@@ -3062,7 +3070,7 @@ class CourseRunEditViewTests(TestCase):
...
@@ -3062,7 +3070,7 @@ class CourseRunEditViewTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
class
CourseRevisionViewTests
(
TestCase
):
class
CourseRevisionViewTests
(
SiteMixin
,
TestCase
):
""" Tests for CourseReview"""
""" Tests for CourseReview"""
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -3114,7 +3122,7 @@ class CourseRevisionViewTests(TestCase):
...
@@ -3114,7 +3122,7 @@ class CourseRevisionViewTests(TestCase):
return
self
.
client
.
get
(
path
=
revision_path
)
return
self
.
client
.
get
(
path
=
revision_path
)
class
CreateRunFromDashboardViewTests
(
TestCase
):
class
CreateRunFromDashboardViewTests
(
SiteMixin
,
TestCase
):
""" Tests for the publisher `CreateRunFromDashboardView`. """
""" Tests for the publisher `CreateRunFromDashboardView`. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -3214,7 +3222,7 @@ class CreateRunFromDashboardViewTests(TestCase):
...
@@ -3214,7 +3222,7 @@ class CreateRunFromDashboardViewTests(TestCase):
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
self
.
assertEqual
(
str
(
mail
.
outbox
[
0
]
.
subject
),
expected_subject
)
class
CreateAdminImportCourseTest
(
TestCase
):
class
CreateAdminImportCourseTest
(
SiteMixin
,
TestCase
):
""" Tests for the publisher `CreateAdminImportCourse`. """
""" Tests for the publisher `CreateAdminImportCourse`. """
def
setUp
(
self
):
def
setUp
(
self
):
...
...
course_discovery/apps/publisher/views.py
View file @
e251012e
...
@@ -394,14 +394,16 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
...
@@ -394,14 +394,16 @@ class CourseEditView(mixins.PublisherPermissionMixin, UpdateView):
if
latest_run
and
latest_run
.
course_run_state
.
name
==
CourseRunStateChoices
.
Published
:
if
latest_run
and
latest_run
.
course_run_state
.
name
==
CourseRunStateChoices
.
Published
:
# If latest run of this course is published send an email to Publisher and don't change state.
# If latest run of this course is published send an email to Publisher and don't change state.
send_email_for_published_course_run_editing
(
latest_run
)
send_email_for_published_course_run_editing
(
latest_run
,
self
.
request
.
site
)
else
:
else
:
user_role
=
self
.
object
.
course_user_roles
.
get
(
user
=
user
)
user_role
=
self
.
object
.
course_user_roles
.
get
(
user
=
user
)
# Change course state to draft if marketing not yet reviewed or
# Change course state to draft if marketing not yet reviewed or
# if marketing person updating the course.
# if marketing person updating the course.
if
not
self
.
object
.
course_state
.
marketing_reviewed
or
user_role
.
role
==
PublisherUserRole
.
MarketingReviewer
:
if
not
self
.
object
.
course_state
.
marketing_reviewed
or
user_role
.
role
==
PublisherUserRole
.
MarketingReviewer
:
if
self
.
object
.
course_state
.
name
!=
CourseStateChoices
.
Draft
:
if
self
.
object
.
course_state
.
name
!=
CourseStateChoices
.
Draft
:
self
.
object
.
course_state
.
change_state
(
state
=
CourseStateChoices
.
Draft
,
user
=
user
)
self
.
object
.
course_state
.
change_state
(
state
=
CourseStateChoices
.
Draft
,
user
=
user
,
site
=
self
.
request
.
site
)
# Change ownership if user role not equal to owner role.
# Change ownership if user role not equal to owner role.
if
self
.
object
.
course_state
.
owner_role
!=
user_role
.
role
:
if
self
.
object
.
course_state
.
owner_role
!=
user_role
.
role
:
...
@@ -599,7 +601,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
...
@@ -599,7 +601,7 @@ class CreateCourseRunView(mixins.LoginRequiredMixin, CreateView):
)
)
messages
.
success
(
request
,
success_msg
)
messages
.
success
(
request
,
success_msg
)
emails
.
send_email_for_course_creation
(
parent_course
,
course_run
)
emails
.
send_email_for_course_creation
(
parent_course
,
course_run
,
request
.
site
)
return
HttpResponseRedirect
(
reverse
(
self
.
success_url
,
kwargs
=
{
'pk'
:
course_run
.
id
}))
return
HttpResponseRedirect
(
reverse
(
self
.
success_url
,
kwargs
=
{
'pk'
:
course_run
.
id
}))
except
Exception
as
error
:
# pylint: disable=broad-except
except
Exception
as
error
:
# pylint: disable=broad-except
# pylint: disable=no-member
# pylint: disable=no-member
...
@@ -740,10 +742,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
...
@@ -740,10 +742,10 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state
=
course_run
.
course_run_state
course_run_state
=
course_run
.
course_run_state
if
course_run_state
.
name
not
in
immutable_states
:
if
course_run_state
.
name
not
in
immutable_states
:
course_run_state
.
change_state
(
state
=
CourseStateChoices
.
Draft
,
user
=
user
)
course_run_state
.
change_state
(
state
=
CourseStateChoices
.
Draft
,
user
=
user
,
site
=
request
.
site
)
if
course_run
.
lms_course_id
and
lms_course_id
!=
course_run
.
lms_course_id
:
if
course_run
.
lms_course_id
and
lms_course_id
!=
course_run
.
lms_course_id
:
emails
.
send_email_for_studio_instance_created
(
course_run
)
emails
.
send_email_for_studio_instance_created
(
course_run
,
site
=
request
.
site
)
# pylint: disable=no-member
# pylint: disable=no-member
messages
.
success
(
request
,
_
(
'Course run updated successfully.'
))
messages
.
success
(
request
,
_
(
'Course run updated successfully.'
))
...
@@ -757,7 +759,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
...
@@ -757,7 +759,7 @@ class CourseRunEditView(mixins.LoginRequiredMixin, mixins.PublisherPermissionMix
course_run_state
.
change_owner_role
(
user_role
)
course_run_state
.
change_owner_role
(
user_role
)
if
CourseRunStateChoices
.
Published
==
course_run_state
.
name
:
if
CourseRunStateChoices
.
Published
==
course_run_state
.
name
:
send_email_for_published_course_run_editing
(
course_run
)
send_email_for_published_course_run_editing
(
course_run
,
request
.
site
)
return
HttpResponseRedirect
(
reverse
(
self
.
success_url
,
kwargs
=
{
'pk'
:
course_run
.
id
}))
return
HttpResponseRedirect
(
reverse
(
self
.
success_url
,
kwargs
=
{
'pk'
:
course_run
.
id
}))
except
Exception
as
e
:
# pylint: disable=broad-except
except
Exception
as
e
:
# pylint: disable=broad-except
...
...
course_discovery/apps/publisher_comments/api/tests/test_views.py
View file @
e251012e
...
@@ -3,6 +3,7 @@ import json
...
@@ -3,6 +3,7 @@ import json
from
django.test
import
TestCase
from
django.test
import
TestCase
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.publisher.tests
import
JSON_CONTENT_TYPE
from
course_discovery.apps.publisher.tests
import
JSON_CONTENT_TYPE
from
course_discovery.apps.publisher.tests.factories
import
CourseRunFactory
from
course_discovery.apps.publisher.tests.factories
import
CourseRunFactory
...
@@ -11,7 +12,7 @@ from course_discovery.apps.publisher_comments.models import Comments
...
@@ -11,7 +12,7 @@ from course_discovery.apps.publisher_comments.models import Comments
from
course_discovery.apps.publisher_comments.tests.factories
import
CommentFactory
from
course_discovery.apps.publisher_comments.tests.factories
import
CommentFactory
class
PostCommentTests
(
TestCase
):
class
PostCommentTests
(
SiteMixin
,
TestCase
):
def
generate_data
(
self
,
obj
):
def
generate_data
(
self
,
obj
):
"""Generate data for the form."""
"""Generate data for the form."""
...
@@ -39,7 +40,7 @@ class PostCommentTests(TestCase):
...
@@ -39,7 +40,7 @@ class PostCommentTests(TestCase):
self
.
assertEqual
(
comment
.
user_email
,
generated_data
[
'email'
])
self
.
assertEqual
(
comment
.
user_email
,
generated_data
[
'email'
])
class
UpdateCommentTests
(
TestCase
):
class
UpdateCommentTests
(
SiteMixin
,
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
UpdateCommentTests
,
self
)
.
setUp
()
super
(
UpdateCommentTests
,
self
)
.
setUp
()
...
...
course_discovery/apps/publisher_comments/tests/test_admin.py
View file @
e251012e
from
django.conf
import
settings
from
django.contrib.sites.models
import
Site
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.publisher.tests
import
factories
from
course_discovery.apps.publisher.tests
import
factories
from
course_discovery.apps.publisher_comments.forms
import
CommentsAdminForm
from
course_discovery.apps.publisher_comments.forms
import
CommentsAdminForm
from
course_discovery.apps.publisher_comments.tests.factories
import
CommentFactory
from
course_discovery.apps.publisher_comments.tests.factories
import
CommentFactory
class
AdminTests
(
TestCase
):
class
AdminTests
(
SiteMixin
,
TestCase
):
""" Tests Admin page and customize form."""
""" Tests Admin page and customize form."""
def
setUp
(
self
):
def
setUp
(
self
):
super
(
AdminTests
,
self
)
.
setUp
()
super
(
AdminTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
self
.
course
=
factories
.
CourseFactory
()
self
.
course
=
factories
.
CourseFactory
()
self
.
comment
=
CommentFactory
(
content_object
=
self
.
course
,
user
=
self
.
user
,
site
=
self
.
site
)
self
.
comment
=
CommentFactory
(
content_object
=
self
.
course
,
user
=
self
.
user
,
site
=
self
.
site
)
...
...
course_discovery/apps/publisher_comments/tests/test_emails.py
View file @
e251012e
import
ddt
import
ddt
import
mock
import
mock
from
django.conf
import
settings
from
django.contrib.sites.models
import
Site
from
django.core
import
mail
from
django.core
import
mail
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.urls
import
reverse
from
django.urls
import
reverse
from
opaque_keys.edx.keys
import
CourseKey
from
opaque_keys.edx.keys
import
CourseKey
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.tests.mixins
import
SiteMixin
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.publisher.choices
import
PublisherUserRole
from
course_discovery.apps.publisher.choices
import
PublisherUserRole
...
@@ -20,7 +19,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
...
@@ -20,7 +19,7 @@ from course_discovery.apps.publisher_comments.tests.factories import CommentFact
@ddt.ddt
@ddt.ddt
class
CommentsEmailTests
(
TestCase
):
class
CommentsEmailTests
(
SiteMixin
,
TestCase
):
""" Tests for the e-mail functionality for course, course-run and seats. """
""" Tests for the e-mail functionality for course, course-run and seats. """
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -30,8 +29,6 @@ class CommentsEmailTests(TestCase):
...
@@ -30,8 +29,6 @@ class CommentsEmailTests(TestCase):
self
.
user_2
=
UserFactory
()
self
.
user_2
=
UserFactory
()
self
.
user_3
=
UserFactory
()
self
.
user_3
=
UserFactory
()
self
.
site
=
Site
.
objects
.
get
(
pk
=
settings
.
SITE_ID
)
self
.
organization_extension
=
factories
.
OrganizationExtensionFactory
()
self
.
organization_extension
=
factories
.
OrganizationExtensionFactory
()
self
.
seat
=
factories
.
SeatFactory
()
self
.
seat
=
factories
.
SeatFactory
()
...
...
course_discovery/settings/base.py
View file @
e251012e
...
@@ -47,6 +47,7 @@ THIRD_PARTY_APPS = [
...
@@ -47,6 +47,7 @@ THIRD_PARTY_APPS = [
'django_fsm'
,
'django_fsm'
,
'storages'
,
'storages'
,
'django_comments'
,
'django_comments'
,
'django_sites_extensions'
,
'taggit'
,
'taggit'
,
'taggit_autosuggest'
,
'taggit_autosuggest'
,
'taggit_serializer'
,
'taggit_serializer'
,
...
@@ -79,6 +80,7 @@ MIDDLEWARE_CLASSES = (
...
@@ -79,6 +80,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware'
,
'django.contrib.auth.middleware.AuthenticationMiddleware'
,
'django.contrib.auth.middleware.SessionAuthenticationMiddleware'
,
'django.contrib.auth.middleware.SessionAuthenticationMiddleware'
,
'django.contrib.messages.middleware.MessageMiddleware'
,
'django.contrib.messages.middleware.MessageMiddleware'
,
'django.contrib.sites.middleware.CurrentSiteMiddleware'
,
'django.middleware.clickjacking.XFrameOptionsMiddleware'
,
'django.middleware.clickjacking.XFrameOptionsMiddleware'
,
'social_django.middleware.SocialAuthExceptionMiddleware'
,
'social_django.middleware.SocialAuthExceptionMiddleware'
,
'waffle.middleware.WaffleMiddleware'
,
'waffle.middleware.WaffleMiddleware'
,
...
@@ -476,7 +478,9 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
...
@@ -476,7 +478,9 @@ DISTINCT_COUNTS_QUERY_CACHE_WARMING_COUNT = 20
DEFAULT_PARTNER_ID
=
None
DEFAULT_PARTNER_ID
=
None
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
# See: https://docs.djangoproject.com/en/dev/ref/settings/#site-id
# edx-django-sites-extensions will fallback to this site if we cannot identify the site from the hostname.
SITE_ID
=
1
SITE_ID
=
1
COMMENTS_APP
=
'course_discovery.apps.publisher_comments'
COMMENTS_APP
=
'course_discovery.apps.publisher_comments'
TAGGIT_CASE_INSENSITIVE
=
True
TAGGIT_CASE_INSENSITIVE
=
True
...
...
course_discovery/settings/test.py
View file @
e251012e
...
@@ -45,3 +45,7 @@ JWT_AUTH['JWT_SECRET_KEY'] = 'course-discovery-jwt-secret-key'
...
@@ -45,3 +45,7 @@ JWT_AUTH['JWT_SECRET_KEY'] = 'course-discovery-jwt-secret-key'
LOGGING
[
'handlers'
][
'local'
]
=
{
'class'
:
'logging.NullHandler'
}
LOGGING
[
'handlers'
][
'local'
]
=
{
'class'
:
'logging.NullHandler'
}
PUBLISHER_FROM_EMAIL
=
'test@example.com'
PUBLISHER_FROM_EMAIL
=
'test@example.com'
# Set to 0 to disable edx-django-sites-extensions to retrieve
# the site from cache and risk working with outdated information.
SITE_CACHE_TTL
=
0
requirements/base.txt
View file @
e251012e
...
@@ -32,6 +32,7 @@ dry-rest-permissions==0.1.6
...
@@ -32,6 +32,7 @@ dry-rest-permissions==0.1.6
edx-auth-backends==1.1.2
edx-auth-backends==1.1.2
edx-ccx-keys==0.2.0
edx-ccx-keys==0.2.0
edx-django-release-util==0.3.1
edx-django-release-util==0.3.1
edx-django-sites-extensions==2.3.0
edx-drf-extensions==1.2.3
edx-drf-extensions==1.2.3
edx-opaque-keys==0.3.1
edx-opaque-keys==0.3.1
edx-rest-api-client==1.6.0
edx-rest-api-client==1.6.0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment