Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
C
course-discovery
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
course-discovery
Commits
e79197f0
Commit
e79197f0
authored
Jul 07, 2017
by
Vedran Karacic
Committed by
Vedran Karačić
Jul 07, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Revert "Filter API data by partner"
This reverts commit
ddd3447e
.
parent
fc642742
Hide whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
221 additions
and
177 deletions
+221
-177
course_discovery/apps/api/filters.py
+3
-1
course_discovery/apps/api/serializers.py
+21
-21
course_discovery/apps/api/tests/mixins.py
+0
-12
course_discovery/apps/api/tests/test_serializers.py
+4
-6
course_discovery/apps/api/v1/tests/test_views/mixins.py
+0
-6
course_discovery/apps/api/v1/tests/test_views/test_affiliate_window.py
+2
-2
course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
+3
-2
course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
+31
-3
course_discovery/apps/api/v1/tests/test_views/test_courses.py
+10
-9
course_discovery/apps/api/v1/tests/test_views/test_organizations.py
+17
-16
course_discovery/apps/api/v1/tests/test_views/test_people.py
+9
-3
course_discovery/apps/api/v1/tests/test_views/test_program_types.py
+2
-2
course_discovery/apps/api/v1/tests/test_views/test_programs.py
+34
-37
course_discovery/apps/api/v1/tests/test_views/test_search.py
+33
-6
course_discovery/apps/api/v1/views/__init__.py
+15
-0
course_discovery/apps/api/v1/views/course_runs.py
+10
-4
course_discovery/apps/api/v1/views/courses.py
+5
-6
course_discovery/apps/api/v1/views/organizations.py
+1
-4
course_discovery/apps/api/v1/views/people.py
+3
-2
course_discovery/apps/api/v1/views/programs.py
+2
-3
course_discovery/apps/api/v1/views/search.py
+3
-2
course_discovery/apps/core/tests/test_views.py
+1
-13
course_discovery/apps/core/views.py
+9
-13
course_discovery/apps/edx_catalog_extensions/api/v1/tests/test_views.py
+3
-3
course_discovery/settings/base.py
+0
-1
No files found.
course_discovery/apps/api/filters.py
View file @
e79197f0
import
logging
import
logging
from
django.conf
import
settings
from
django.contrib.auth
import
get_user_model
from
django.contrib.auth
import
get_user_model
from
django.db.models
import
QuerySet
from
django.db.models
import
QuerySet
from
django.utils.translation
import
ugettext
as
_
from
django.utils.translation
import
ugettext
as
_
...
@@ -12,6 +13,7 @@ from guardian.shortcuts import get_objects_for_user
...
@@ -12,6 +13,7 @@ from guardian.shortcuts import get_objects_for_user
from
rest_framework.exceptions
import
NotFound
,
PermissionDenied
from
rest_framework.exceptions
import
NotFound
,
PermissionDenied
from
course_discovery.apps.api.utils
import
cast2int
from
course_discovery.apps.api.utils
import
cast2int
from
course_discovery.apps.core.models
import
Partner
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Organization
,
Program
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Organization
,
Program
...
@@ -91,7 +93,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
...
@@ -91,7 +93,7 @@ class HaystackFilter(HaystackRequestFilterMixin, DefaultHaystackFilter):
# Return data for the default partner, if no partner is requested
# Return data for the default partner, if no partner is requested
if
not
any
(
field
in
filters
for
field
in
(
'partner'
,
'partner_exact'
)):
if
not
any
(
field
in
filters
for
field
in
(
'partner'
,
'partner_exact'
)):
filters
[
'partner'
]
=
request
.
site
.
partner
.
short_code
filters
[
'partner'
]
=
Partner
.
objects
.
get
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
.
short_code
return
filters
return
filters
...
...
course_discovery/apps/api/serializers.py
View file @
e79197f0
...
@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
...
@@ -369,8 +369,8 @@ class OrganizationSerializer(TaggitSerializer, MinimalOrganizationSerializer):
tags
=
TagListSerializerField
()
tags
=
TagListSerializerField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
partner
):
def
prefetch_queryset
(
cls
):
return
Organization
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'partner'
)
.
prefetch_related
(
'tags'
)
return
Organization
.
objects
.
all
(
)
.
select_related
(
'partner'
)
.
prefetch_related
(
'tags'
)
class
Meta
(
MinimalOrganizationSerializer
.
Meta
):
class
Meta
(
MinimalOrganizationSerializer
.
Meta
):
fields
=
MinimalOrganizationSerializer
.
Meta
.
fields
+
(
fields
=
MinimalOrganizationSerializer
.
Meta
.
fields
+
(
...
@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
...
@@ -551,18 +551,18 @@ class CourseSerializer(MinimalCourseSerializer):
marketing_url
=
serializers
.
SerializerMethodField
()
marketing_url
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
,
partner
=
None
):
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
# Explicitly check for None to avoid returning all Courses when the
# Explicitly check for None to avoid returning all Courses when the
# queryset passed in happens to be empty.
# queryset passed in happens to be empty.
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
)
)
class
Meta
(
MinimalCourseSerializer
.
Meta
):
class
Meta
(
MinimalCourseSerializer
.
Meta
):
...
@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
...
@@ -586,20 +586,20 @@ class CourseWithProgramsSerializer(CourseSerializer):
programs
=
serializers
.
SerializerMethodField
()
programs
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
,
partner
=
None
):
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
"""
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
filtered CourseRun queryset.
"""
"""
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
)
)
def
get_course_runs
(
self
,
course
):
def
get_course_runs
(
self
,
course
):
...
@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
...
@@ -634,20 +634,20 @@ class CatalogCourseSerializer(CourseSerializer):
course_runs
=
serializers
.
SerializerMethodField
()
course_runs
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
,
partner
=
None
):
def
prefetch_queryset
(
cls
,
queryset
=
None
,
course_runs
=
None
):
"""
"""
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
Similar to the CourseSerializer's prefetch_queryset, but prefetches a
filtered CourseRun queryset.
filtered CourseRun queryset.
"""
"""
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
filter
(
partner
=
partner
)
queryset
=
queryset
if
queryset
is
not
None
else
Course
.
objects
.
all
(
)
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
queryset
.
select_related
(
'level_type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'expected_learning_items'
,
'expected_learning_items'
,
'prerequisites'
,
'prerequisites'
,
'subjects'
,
'subjects'
,
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'course_runs'
,
queryset
=
CourseRunSerializer
.
prefetch_queryset
(
queryset
=
course_runs
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'sponsoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
)
)
def
get_course_runs
(
self
,
course
):
def
get_course_runs
(
self
,
course
):
...
@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
...
@@ -703,8 +703,8 @@ class MinimalProgramSerializer(serializers.ModelSerializer):
type
=
serializers
.
SlugRelatedField
(
slug_field
=
'name'
,
queryset
=
ProgramType
.
objects
.
all
())
type
=
serializers
.
SlugRelatedField
(
slug_field
=
'name'
,
queryset
=
ProgramType
.
objects
.
all
())
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
partner
):
def
prefetch_queryset
(
cls
):
return
Program
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'type'
,
'partner'
)
.
prefetch_related
(
return
Program
.
objects
.
all
(
)
.
select_related
(
'type'
,
'partner'
)
.
prefetch_related
(
'excluded_course_runs'
,
'excluded_course_runs'
,
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# `type` is serialized by a third-party serializer. Providing this field name allows us to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
# prefetch `applicable_seat_types`, a m2m on `ProgramType`, through `type`, a foreign key to
...
@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -828,7 +828,7 @@ class ProgramSerializer(MinimalProgramSerializer):
applicable_seat_types
=
serializers
.
SerializerMethodField
()
applicable_seat_types
=
serializers
.
SerializerMethodField
()
@classmethod
@classmethod
def
prefetch_queryset
(
cls
,
partner
):
def
prefetch_queryset
(
cls
):
"""
"""
Prefetch the related objects that will be serialized with a `Program`.
Prefetch the related objects that will be serialized with a `Program`.
...
@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -836,7 +836,7 @@ class ProgramSerializer(MinimalProgramSerializer):
chain of related fields from programs to course runs (i.e., we want control over
chain of related fields from programs to course runs (i.e., we want control over
the querysets that we're prefetching).
the querysets that we're prefetching).
"""
"""
return
Program
.
objects
.
filter
(
partner
=
partner
)
.
select_related
(
'type'
,
'video'
,
'partner'
)
.
prefetch_related
(
return
Program
.
objects
.
all
(
)
.
select_related
(
'type'
,
'video'
,
'partner'
)
.
prefetch_related
(
'excluded_course_runs'
,
'excluded_course_runs'
,
'expected_learning_items'
,
'expected_learning_items'
,
'faq'
,
'faq'
,
...
@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
...
@@ -847,9 +847,9 @@ class ProgramSerializer(MinimalProgramSerializer):
'type__applicable_seat_types'
,
'type__applicable_seat_types'
,
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch
(
'courses'
,
queryset
=
CourseSerializer
.
prefetch_queryset
(
partner
=
partner
)),
Prefetch
(
'courses'
,
queryset
=
CourseSerializer
.
prefetch_queryset
()),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'authoring_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'credit_backing_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
(
partner
)),
Prefetch
(
'credit_backing_organizations'
,
queryset
=
OrganizationSerializer
.
prefetch_queryset
()),
Prefetch
(
'corporate_endorsements'
,
queryset
=
CorporateEndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'corporate_endorsements'
,
queryset
=
CorporateEndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'individual_endorsements'
,
queryset
=
EndorsementSerializer
.
prefetch_queryset
()),
Prefetch
(
'individual_endorsements'
,
queryset
=
EndorsementSerializer
.
prefetch_queryset
()),
)
)
...
...
course_discovery/apps/api/tests/mixins.py
deleted
100644 → 0
View file @
fc642742
from
django.conf
import
settings
from
django.contrib.sites.models
import
Site
from
course_discovery.apps.core.tests.factories
import
PartnerFactory
,
SiteFactory
class
PartnerMixin
(
object
):
def
setUp
(
self
):
super
(
PartnerMixin
,
self
)
.
setUp
()
Site
.
objects
.
all
()
.
delete
()
self
.
site
=
SiteFactory
(
id
=
settings
.
SITE_ID
)
self
.
partner
=
PartnerFactory
(
site
=
self
.
site
)
course_discovery/apps/api/tests/test_serializers.py
View file @
e79197f0
...
@@ -21,7 +21,6 @@ from course_discovery.apps.api.serializers import (
...
@@ -21,7 +21,6 @@ from course_discovery.apps.api.serializers import (
ProgramSerializer
,
ProgramTypeSerializer
,
SeatSerializer
,
SubjectSerializer
,
TypeaheadCourseRunSearchSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
,
SeatSerializer
,
SubjectSerializer
,
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
,
VideoSerializer
TypeaheadProgramSearchSerializer
,
VideoSerializer
)
)
from
course_discovery.apps.api.tests.mixins
import
PartnerMixin
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.catalogs.tests.factories
import
CatalogFactory
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.models
import
User
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
...
@@ -97,7 +96,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
...
@@ -97,7 +96,7 @@ class CatalogSerializerTests(ElasticsearchTestMixin, TestCase):
self
.
assertEqual
(
User
.
objects
.
filter
(
username
=
username
)
.
count
(),
0
)
# pylint: disable=no-member
self
.
assertEqual
(
User
.
objects
.
filter
(
username
=
username
)
.
count
(),
0
)
# pylint: disable=no-member
class
MinimalCourseSerializerTests
(
PartnerMixin
,
TestCase
):
class
MinimalCourseSerializerTests
(
TestCase
):
serializer_class
=
MinimalCourseSerializer
serializer_class
=
MinimalCourseSerializer
def
get_expected_data
(
self
,
course
,
request
):
def
get_expected_data
(
self
,
course
,
request
):
...
@@ -114,8 +113,8 @@ class MinimalCourseSerializerTests(PartnerMixin, TestCase):
...
@@ -114,8 +113,8 @@ class MinimalCourseSerializerTests(PartnerMixin, TestCase):
def
test_data
(
self
):
def
test_data
(
self
):
request
=
make_request
()
request
=
make_request
()
organizations
=
OrganizationFactory
(
partner
=
self
.
partner
)
organizations
=
OrganizationFactory
()
course
=
CourseFactory
(
authoring_organizations
=
[
organizations
]
,
partner
=
self
.
partner
)
course
=
CourseFactory
(
authoring_organizations
=
[
organizations
])
CourseRunFactory
.
create_batch
(
2
,
course
=
course
)
CourseRunFactory
.
create_batch
(
2
,
course
=
course
)
serializer
=
self
.
serializer_class
(
course
,
context
=
{
'request'
:
request
})
serializer
=
self
.
serializer_class
(
course
,
context
=
{
'request'
:
request
})
expected
=
self
.
get_expected_data
(
course
,
request
)
expected
=
self
.
get_expected_data
(
course
,
request
)
...
@@ -178,10 +177,9 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
...
@@ -178,10 +177,9 @@ class CourseWithProgramsSerializerTests(CourseSerializerTests):
def
setUp
(
self
):
def
setUp
(
self
):
super
()
.
setUp
()
super
()
.
setUp
()
self
.
request
=
make_request
()
self
.
request
=
make_request
()
self
.
course
=
CourseFactory
(
partner
=
self
.
partner
)
self
.
course
=
CourseFactory
()
self
.
deleted_program
=
ProgramFactory
(
self
.
deleted_program
=
ProgramFactory
(
courses
=
[
self
.
course
],
courses
=
[
self
.
course
],
partner
=
self
.
partner
,
status
=
ProgramStatus
.
Deleted
status
=
ProgramStatus
.
Deleted
)
)
...
...
course_discovery/apps/api/v1/tests/test_views/mixins.py
View file @
e79197f0
...
@@ -4,7 +4,6 @@ import json
...
@@ -4,7 +4,6 @@ import json
import
responses
import
responses
from
django.conf
import
settings
from
django.conf
import
settings
from
rest_framework.test
import
APITestCase
as
RestAPITestCase
from
rest_framework.test
import
APIRequestFactory
from
rest_framework.test
import
APIRequestFactory
from
course_discovery.apps.api.serializers
import
(
from
course_discovery.apps.api.serializers
import
(
...
@@ -12,7 +11,6 @@ from course_discovery.apps.api.serializers import (
...
@@ -12,7 +11,6 @@ from course_discovery.apps.api.serializers import (
CourseWithProgramsSerializer
,
FlattenedCourseRunWithCourseSerializer
,
MinimalProgramSerializer
,
CourseWithProgramsSerializer
,
FlattenedCourseRunWithCourseSerializer
,
MinimalProgramSerializer
,
OrganizationSerializer
,
PersonSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
OrganizationSerializer
,
PersonSerializer
,
ProgramSerializer
,
ProgramTypeSerializer
)
)
from
course_discovery.apps.api.tests.mixins
import
PartnerMixin
class
SerializationMixin
(
object
):
class
SerializationMixin
(
object
):
...
@@ -90,7 +88,3 @@ class OAuth2Mixin(object):
...
@@ -90,7 +88,3 @@ class OAuth2Mixin(object):
content_type
=
'application/json'
,
content_type
=
'application/json'
,
status
=
status
status
=
status
)
)
class
APITestCase
(
PartnerMixin
,
RestAPITestCase
):
pass
course_discovery/apps/api/v1/tests/test_views/test_affiliate_window.py
View file @
e79197f0
...
@@ -46,7 +46,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
...
@@ -46,7 +46,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
def
test_affiliate_with_supported_seats
(
self
):
def
test_affiliate_with_supported_seats
(
self
):
""" Verify that endpoint returns course runs for verified and professional seats only. """
""" Verify that endpoint returns course runs for verified and professional seats only. """
with
self
.
assertNumQueries
(
9
):
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
get
(
self
.
affiliate_url
)
response
=
self
.
client
.
get
(
self
.
affiliate_url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -130,7 +130,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
...
@@ -130,7 +130,7 @@ class AffiliateWindowViewSetTests(ElasticsearchTestMixin, SerializationMixin, AP
# Superusers can view all catalogs
# Superusers can view all catalogs
self
.
client
.
force_authenticate
(
superuser
)
self
.
client
.
force_authenticate
(
superuser
)
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
4
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
View file @
e79197f0
...
@@ -185,7 +185,8 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -185,7 +185,8 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
# Any course appearing in the response must have at least one serialized run.
# Any course appearing in the response must have at least one serialized run.
assert
len
(
response
.
data
[
'results'
][
0
][
'course_runs'
])
>
0
assert
len
(
response
.
data
[
'results'
][
0
][
'course_runs'
])
>
0
else
:
else
:
response
=
self
.
client
.
get
(
url
)
with
self
.
assertNumQueries
(
3
):
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
assert
response
.
data
[
'results'
]
==
[]
assert
response
.
data
[
'results'
]
==
[]
...
@@ -217,7 +218,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
...
@@ -217,7 +218,7 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
url
=
reverse
(
'api:v1:catalog-csv'
,
kwargs
=
{
'id'
:
self
.
catalog
.
id
})
url
=
reverse
(
'api:v1:catalog-csv'
,
kwargs
=
{
'id'
:
self
.
catalog
.
id
})
with
self
.
assertNumQueries
(
1
8
):
with
self
.
assertNumQueries
(
1
7
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
course_run
=
self
.
serialize_catalog_flat_course_run
(
self
.
course_run
)
course_run
=
self
.
serialize_catalog_flat_course_run
(
self
.
course_run
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
View file @
e79197f0
...
@@ -4,16 +4,19 @@ import urllib
...
@@ -4,16 +4,19 @@ import urllib
import
ddt
import
ddt
import
pytz
import
pytz
from
django.conf
import
settings
from
django.db.models.functions
import
Lower
from
django.db.models.functions
import
Lower
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APIRequestFactory
from
rest_framework.test
import
APIRequestFactory
,
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.tests.factories
import
CourseRunFactory
,
ProgramFactory
,
SeatFactory
from
course_discovery.apps.course_metadata.tests.factories
import
(
CourseRunFactory
,
PartnerFactory
,
ProgramFactory
,
SeatFactory
)
@ddt.ddt
@ddt.ddt
...
@@ -22,6 +25,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -22,6 +25,10 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
super
(
CourseRunViewSetTests
,
self
)
.
setUp
()
super
(
CourseRunViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
client
.
force_authenticate
(
self
.
user
)
self
.
client
.
force_authenticate
(
self
.
user
)
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self
.
partner
=
PartnerFactory
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
self
.
course_run
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run_2
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
course_run_2
=
CourseRunFactory
(
course__partner
=
self
.
partner
)
self
.
refresh_index
()
self
.
refresh_index
()
...
@@ -163,6 +170,15 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -163,6 +170,15 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
key
=
lambda
course_run
:
course_run
[
'key'
])
key
=
lambda
course_run
:
course_run
[
'key'
])
self
.
assertListEqual
(
actual_sorted
,
expected_sorted
)
self
.
assertListEqual
(
actual_sorted
,
expected_sorted
)
def
test_list_query_invalid_partner
(
self
):
""" Verify the endpoint returns an 400 BAD_REQUEST if an invalid partner is sent """
query
=
'title:Some random title'
url
=
'{root}?q={query}&partner={partner}'
.
format
(
root
=
reverse
(
'api:v1:course_run-list'
),
query
=
query
,
partner
=
'foo'
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
400
)
def
assert_list_results
(
self
,
url
,
expected
,
extra_context
=
None
):
def
assert_list_results
(
self
,
url
,
expected
,
extra_context
=
None
):
expected
=
sorted
(
expected
,
key
=
lambda
course_run
:
course_run
.
key
.
lower
())
expected
=
sorted
(
expected
,
key
=
lambda
course_run
:
course_run
.
key
.
lower
())
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
@@ -252,6 +268,18 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
...
@@ -252,6 +268,18 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
}
}
)
)
def
test_contains_single_course_run_invalid_partner
(
self
):
""" Verify that a 400 BAD_REQUEST is thrown when passing an invalid partner """
qs
=
urllib
.
parse
.
urlencode
({
'query'
:
'id:course*'
,
'course_run_ids'
:
self
.
course_run
.
key
,
'partner'
:
'foo'
})
url
=
'{}?{}'
.
format
(
reverse
(
'api:v1:course_run-contains'
),
qs
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
400
def
test_contains_multiple_course_runs
(
self
):
def
test_contains_multiple_course_runs
(
self
):
qs
=
urllib
.
parse
.
urlencode
({
qs
=
urllib
.
parse
.
urlencode
({
'query'
:
'id:course*'
,
'query'
:
'id:course*'
,
...
...
course_discovery/apps/api/v1/tests/test_views/test_courses.py
View file @
e79197f0
...
@@ -4,8 +4,9 @@ import ddt
...
@@ -4,8 +4,9 @@ import ddt
import
pytz
import
pytz
from
django.db.models.functions
import
Lower
from
django.db.models.functions
import
Lower
from
rest_framework.reverse
import
reverse
from
rest_framework.reverse
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
from
course_discovery.apps.course_metadata.models
import
Course
...
@@ -22,13 +23,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -22,13 +23,13 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
super
(
CourseViewSetTests
,
self
)
.
setUp
()
super
(
CourseViewSetTests
,
self
)
.
setUp
()
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
user
=
UserFactory
(
is_staff
=
True
,
is_superuser
=
True
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
client
.
login
(
username
=
self
.
user
.
username
,
password
=
USER_PASSWORD
)
self
.
course
=
CourseFactory
(
partner
=
self
.
partner
)
self
.
course
=
CourseFactory
()
def
test_get
(
self
):
def
test_get
(
self
):
""" Verify the endpoint returns the details for a single course. """
""" Verify the endpoint returns the details for a single course. """
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
with
self
.
assertNumQueries
(
20
):
with
self
.
assertNumQueries
(
18
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
data
,
self
.
serialize_course
(
self
.
course
))
self
.
assertEqual
(
response
.
data
,
self
.
serialize_course
(
self
.
course
))
...
@@ -37,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -37,7 +38,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns no deleted associated programs """
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
with
self
.
assertNumQueries
(
1
3
):
with
self
.
assertNumQueries
(
1
1
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
data
.
get
(
'programs'
),
[])
self
.
assertEqual
(
response
.
data
.
get
(
'programs'
),
[])
...
@@ -50,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -50,7 +51,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
ProgramFactory
(
courses
=
[
self
.
course
],
status
=
ProgramStatus
.
Deleted
)
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
=
reverse
(
'api:v1:course-detail'
,
kwargs
=
{
'key'
:
self
.
course
.
key
})
url
+=
'?include_deleted_programs=1'
url
+=
'?include_deleted_programs=1'
with
self
.
assertNumQueries
(
2
4
):
with
self
.
assertNumQueries
(
2
2
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
self
.
assertEqual
(
...
@@ -186,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -186,7 +187,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all courses. """
""" Verify the endpoint returns a list of all courses. """
url
=
reverse
(
'api:v1:course-list'
)
url
=
reverse
(
'api:v1:course-list'
)
with
self
.
assertNumQueries
(
2
6
):
with
self
.
assertNumQueries
(
2
4
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertListEqual
(
self
.
assertListEqual
(
...
@@ -202,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
...
@@ -202,18 +203,18 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
query
=
'title:'
+
title
query
=
'title:'
+
title
url
=
'{root}?q={query}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
query
=
query
)
url
=
'{root}?q={query}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
query
=
query
)
with
self
.
assertNumQueries
(
3
9
):
with
self
.
assertNumQueries
(
3
7
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
def
test_list_key_filter
(
self
):
def
test_list_key_filter
(
self
):
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
""" Verify the endpoint returns a list of courses filtered by the specified keys. """
courses
=
CourseFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
courses
=
CourseFactory
.
create_batch
(
3
)
courses
=
sorted
(
courses
,
key
=
lambda
course
:
course
.
key
.
lower
())
courses
=
sorted
(
courses
,
key
=
lambda
course
:
course
.
key
.
lower
())
keys
=
','
.
join
([
course
.
key
for
course
in
courses
])
keys
=
','
.
join
([
course
.
key
for
course
in
courses
])
url
=
'{root}?keys={keys}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
keys
=
keys
)
url
=
'{root}?keys={keys}'
.
format
(
root
=
reverse
(
'api:v1:course-list'
),
keys
=
keys
)
with
self
.
assertNumQueries
(
3
9
):
with
self
.
assertNumQueries
(
3
7
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
self
.
assertListEqual
(
response
.
data
[
'results'
],
self
.
serialize_course
(
courses
,
many
=
True
))
...
...
course_discovery/apps/api/v1/tests/test_views/test_organizations.py
View file @
e79197f0
import
uuid
import
uuid
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.course_metadata.tests.factories
import
Organization
,
OrganizationFactory
from
course_discovery.apps.course_metadata.tests.factories
import
Organization
,
OrganizationFactory
...
@@ -26,19 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -26,19 +27,17 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
assert_response_data_valid
(
self
,
response
,
organizations
,
many
=
True
):
def
assert_response_data_valid
(
self
,
response
,
organizations
,
many
=
True
):
""" Asserts the response data (only) contains the expected organizations. """
""" Asserts the response data (only) contains the expected organizations. """
actual
=
response
.
data
serializer_data
=
self
.
serialize_organization
(
organizations
,
many
=
many
)
actual
=
response
.
data
if
many
:
if
many
:
actual
=
actual
[
'results'
]
actual
=
actual
[
'results'
]
actual
=
sorted
(
actual
,
key
=
lambda
k
:
k
[
'uuid'
])
serializer_data
=
sorted
(
serializer_data
,
key
=
lambda
k
:
k
[
'uuid'
])
self
.
assertEqual
(
actual
,
se
rializer_data
)
self
.
assertEqual
(
actual
,
se
lf
.
serialize_organization
(
organizations
,
many
=
many
)
)
def
assert_list_uuid_filter
(
self
,
organizations
,
expected_query_count
):
def
assert_list_uuid_filter
(
self
,
organizations
):
""" Asserts the list endpoint supports filtering by UUID. """
""" Asserts the list endpoint supports filtering by UUID. """
with
self
.
assertNumQueries
(
expected_query_count
):
with
self
.
assertNumQueries
(
5
):
uuids
=
','
.
join
([
organization
.
uuid
.
hex
for
organization
in
organizations
])
uuids
=
','
.
join
([
organization
.
uuid
.
hex
for
organization
in
organizations
])
url
=
'{root}?uuids={uuids}'
.
format
(
root
=
self
.
list_path
,
uuids
=
uuids
)
url
=
'{root}?uuids={uuids}'
.
format
(
root
=
self
.
list_path
,
uuids
=
uuids
)
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
@@ -48,6 +47,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -48,6 +47,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
assert_list_tag_filter
(
self
,
organizations
,
tags
,
expected_query_count
=
5
):
def
assert_list_tag_filter
(
self
,
organizations
,
tags
,
expected_query_count
=
5
):
""" Asserts the list endpoint supports filtering by tags. """
""" Asserts the list endpoint supports filtering by tags. """
with
self
.
assertNumQueries
(
expected_query_count
):
with
self
.
assertNumQueries
(
expected_query_count
):
tags
=
','
.
join
(
tags
)
tags
=
','
.
join
(
tags
)
url
=
'{root}?tags={tags}'
.
format
(
root
=
self
.
list_path
,
tags
=
tags
)
url
=
'{root}?tags={tags}'
.
format
(
root
=
self
.
list_path
,
tags
=
tags
)
...
@@ -58,9 +58,10 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -58,9 +58,10 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_list
(
self
):
def
test_list
(
self
):
""" Verify the endpoint returns a list of all organizations. """
""" Verify the endpoint returns a list of all organizations. """
OrganizationFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
with
self
.
assertNumQueries
(
7
):
OrganizationFactory
.
create_batch
(
3
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
self
.
list_path
)
response
=
self
.
client
.
get
(
self
.
list_path
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -69,22 +70,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -69,22 +70,22 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_list_uuid_filter
(
self
):
def
test_list_uuid_filter
(
self
):
""" Verify the endpoint returns a list of organizations filtered by UUID. """
""" Verify the endpoint returns a list of organizations filtered by UUID. """
organizations
=
OrganizationFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
organizations
=
OrganizationFactory
.
create_batch
(
3
)
# Test with a single UUID
# Test with a single UUID
self
.
assert_list_uuid_filter
([
organizations
[
0
]]
,
7
)
self
.
assert_list_uuid_filter
([
organizations
[
0
]])
# Test with multiple UUIDs
# Test with multiple UUIDs
self
.
assert_list_uuid_filter
(
organizations
,
5
)
self
.
assert_list_uuid_filter
(
organizations
)
def
test_list_tag_filter
(
self
):
def
test_list_tag_filter
(
self
):
""" Verify the endpoint returns a list of organizations filtered by tag. """
""" Verify the endpoint returns a list of organizations filtered by tag. """
tag
=
'test-org'
tag
=
'test-org'
organizations
=
OrganizationFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
organizations
=
OrganizationFactory
.
create_batch
(
2
)
# If no organizations have been tagged, the endpoint should not return any data
# If no organizations have been tagged, the endpoint should not return any data
self
.
assert_list_tag_filter
([],
[
tag
],
expected_query_count
=
6
)
self
.
assert_list_tag_filter
([],
[
tag
],
expected_query_count
=
4
)
# Tagged organizations should be returned
# Tagged organizations should be returned
organizations
[
0
]
.
tags
.
add
(
tag
)
organizations
[
0
]
.
tags
.
add
(
tag
)
...
@@ -98,7 +99,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
...
@@ -98,7 +99,7 @@ class OrganizationViewSetTests(SerializationMixin, APITestCase):
def
test_retrieve
(
self
):
def
test_retrieve
(
self
):
""" Verify the endpoint returns details for a single organization. """
""" Verify the endpoint returns details for a single organization. """
organization
=
OrganizationFactory
(
partner
=
self
.
partner
)
organization
=
OrganizationFactory
()
url
=
reverse
(
'api:v1:organization-detail'
,
kwargs
=
{
'uuid'
:
organization
.
uuid
})
url
=
reverse
(
'api:v1:organization-detail'
,
kwargs
=
{
'uuid'
:
organization
.
uuid
})
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
...
...
course_discovery/apps/api/v1/tests/test_views/test_people.py
View file @
e79197f0
# pylint: disable=redefined-builtin,no-member
# pylint: disable=redefined-builtin,no-member
import
ddt
import
ddt
from
django.conf
import
settings
from
django.contrib.auth
import
get_user_model
from
django.contrib.auth
import
get_user_model
from
django.db
import
IntegrityError
from
django.db
import
IntegrityError
from
mock
import
mock
from
mock
import
mock
...
@@ -7,19 +8,20 @@ from rest_framework.reverse import reverse
...
@@ -7,19 +8,20 @@ from rest_framework.reverse import reverse
from
rest_framework.test
import
APITestCase
from
rest_framework.test
import
APITestCase
from
testfixtures
import
LogCapture
from
testfixtures
import
LogCapture
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
PartnerMixin
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.views.people
import
logger
as
people_logger
from
course_discovery.apps.api.v1.views.people
import
logger
as
people_logger
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.core.tests.factories
import
UserFactory
from
course_discovery.apps.course_metadata.models
import
Person
from
course_discovery.apps.course_metadata.models
import
Person
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests
import
toggle_switch
from
course_discovery.apps.course_metadata.tests.factories
import
OrganizationFactory
,
PersonFactory
,
PositionFactory
from
course_discovery.apps.course_metadata.tests.factories
import
(
OrganizationFactory
,
PartnerFactory
,
PersonFactory
,
PositionFactory
)
User
=
get_user_model
()
User
=
get_user_model
()
@ddt.ddt
@ddt.ddt
class
PersonViewSetTests
(
SerializationMixin
,
PartnerMixin
,
APITestCase
):
class
PersonViewSetTests
(
SerializationMixin
,
APITestCase
):
""" Tests for the person resource. """
""" Tests for the person resource. """
people_list_url
=
reverse
(
'api:v1:person-list'
)
people_list_url
=
reverse
(
'api:v1:person-list'
)
...
@@ -30,6 +32,10 @@ class PersonViewSetTests(SerializationMixin, PartnerMixin, APITestCase):
...
@@ -30,6 +32,10 @@ class PersonViewSetTests(SerializationMixin, PartnerMixin, APITestCase):
self
.
person
=
PersonFactory
()
self
.
person
=
PersonFactory
()
PositionFactory
(
person
=
self
.
person
)
PositionFactory
(
person
=
self
.
person
)
self
.
organization
=
OrganizationFactory
()
self
.
organization
=
OrganizationFactory
()
# DEFAULT_PARTNER_ID is used explicitly here to avoid issues with differences in
# auto-incrementing behavior across databases. Otherwise, it's not safe to assume
# that the partner created here will always have id=DEFAULT_PARTNER_ID.
self
.
partner
=
PartnerFactory
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
toggle_switch
(
'publish_person_to_marketing_site'
,
True
)
toggle_switch
(
'publish_person_to_marketing_site'
,
True
)
self
.
expected_node
=
{
self
.
expected_node
=
{
'resource'
:
'node'
,
''
'resource'
:
'node'
,
''
...
...
course_discovery/apps/api/v1/tests/test_views/test_program_types.py
View file @
e79197f0
...
@@ -28,7 +28,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
...
@@ -28,7 +28,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all program types. """
""" Verify the endpoint returns a list of all program types. """
ProgramTypeFactory
.
create_batch
(
4
)
ProgramTypeFactory
.
create_batch
(
4
)
expected
=
ProgramType
.
objects
.
all
()
expected
=
ProgramType
.
objects
.
all
()
with
self
.
assertNumQueries
(
6
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
self
.
list_path
)
response
=
self
.
client
.
get
(
self
.
list_path
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
@@ -39,7 +39,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
...
@@ -39,7 +39,7 @@ class ProgramTypeViewSetTests(SerializationMixin, APITestCase):
program_type
=
ProgramTypeFactory
()
program_type
=
ProgramTypeFactory
()
url
=
reverse
(
'api:v1:program_type-detail'
,
kwargs
=
{
'slug'
:
program_type
.
slug
})
url
=
reverse
(
'api:v1:program_type-detail'
,
kwargs
=
{
'slug'
:
program_type
.
slug
})
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
4
):
response
=
self
.
client
.
get
(
url
)
response
=
self
.
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
...
...
course_discovery/apps/api/v1/tests/test_views/test_programs.py
View file @
e79197f0
...
@@ -3,9 +3,10 @@ import urllib.parse
...
@@ -3,9 +3,10 @@ import urllib.parse
import
ddt
import
ddt
from
django.core.cache
import
cache
from
django.core.cache
import
cache
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.serializers
import
MinimalProgramSerializer
from
course_discovery.apps.api.serializers
import
MinimalProgramSerializer
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
,
SerializationMixin
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
SerializationMixin
from
course_discovery.apps.api.v1.views.programs
import
ProgramViewSet
from
course_discovery.apps.api.v1.views.programs
import
ProgramViewSet
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
UserFactory
from
course_discovery.apps.core.tests.helpers
import
make_image_file
from
course_discovery.apps.core.tests.helpers
import
make_image_file
...
@@ -30,10 +31,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -30,10 +31,10 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
cache
.
clear
()
cache
.
clear
()
def
create_program
(
self
):
def
create_program
(
self
):
organizations
=
[
OrganizationFactory
(
partner
=
self
.
partner
)]
organizations
=
[
OrganizationFactory
()]
person
=
PersonFactory
()
person
=
PersonFactory
()
course
=
CourseFactory
(
partner
=
self
.
partner
)
course
=
CourseFactory
()
CourseRunFactory
(
course
=
course
,
staff
=
[
person
])
CourseRunFactory
(
course
=
course
,
staff
=
[
person
])
program
=
ProgramFactory
(
program
=
ProgramFactory
(
...
@@ -45,8 +46,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -45,8 +46,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
expected_learning_items
=
ExpectedLearningItemFactory
.
create_batch
(
1
),
expected_learning_items
=
ExpectedLearningItemFactory
.
create_batch
(
1
),
job_outlook_items
=
JobOutlookItemFactory
.
create_batch
(
1
),
job_outlook_items
=
JobOutlookItemFactory
.
create_batch
(
1
),
banner_image
=
make_image_file
(
'test_banner.jpg'
),
banner_image
=
make_image_file
(
'test_banner.jpg'
),
video
=
VideoFactory
(),
video
=
VideoFactory
()
partner
=
self
.
partner
)
)
return
program
return
program
...
@@ -73,7 +73,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -73,7 +73,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_retrieve
(
self
):
def
test_retrieve
(
self
):
""" Verify the endpoint returns the details for a single program. """
""" Verify the endpoint returns the details for a single program. """
program
=
self
.
create_program
()
program
=
self
.
create_program
()
with
self
.
assertNumQueries
(
3
9
):
with
self
.
assertNumQueries
(
3
7
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
# property does not have the right values while being indexed
# property does not have the right values while being indexed
del
program
.
_course_run_weeks_to_complete
del
program
.
_course_run_weeks_to_complete
...
@@ -90,25 +90,22 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -90,25 +90,22 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
@ddt.data
(
True
,
False
)
@ddt.data
(
True
,
False
)
def
test_retrieve_with_sorting_flag
(
self
,
order_courses_by_start_date
):
def
test_retrieve_with_sorting_flag
(
self
,
order_courses_by_start_date
):
""" Verify the number of queries is the same with sorting flag set to true. """
""" Verify the number of queries is the same with sorting flag set to true. """
course_list
=
CourseFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
course_list
=
CourseFactory
.
create_batch
(
3
)
for
course
in
course_list
:
for
course
in
course_list
:
CourseRunFactory
(
course
=
course
)
CourseRunFactory
(
course
=
course
)
program
=
ProgramFactory
(
program
=
ProgramFactory
(
courses
=
course_list
,
order_courses_by_start_date
=
order_courses_by_start_date
)
courses
=
course_list
,
order_courses_by_start_date
=
order_courses_by_start_date
,
partner
=
self
.
partner
)
# property does not have the right values while being indexed
# property does not have the right values while being indexed
del
program
.
_course_run_weeks_to_complete
del
program
.
_course_run_weeks_to_complete
with
self
.
assertNumQueries
(
2
8
):
with
self
.
assertNumQueries
(
2
6
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
self
.
assertEqual
(
course_list
,
list
(
program
.
courses
.
all
()))
# pylint: disable=no-member
self
.
assertEqual
(
course_list
,
list
(
program
.
courses
.
all
()))
# pylint: disable=no-member
def
test_retrieve_without_course_runs
(
self
):
def
test_retrieve_without_course_runs
(
self
):
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
""" Verify the endpoint returns data for a program even if the program's courses have no course runs. """
course
=
CourseFactory
(
partner
=
self
.
partner
)
course
=
CourseFactory
()
program
=
ProgramFactory
(
courses
=
[
course
]
,
partner
=
self
.
partner
)
program
=
ProgramFactory
(
courses
=
[
course
])
with
self
.
assertNumQueries
(
2
2
):
with
self
.
assertNumQueries
(
2
0
):
response
=
self
.
assert_retrieve_success
(
program
)
response
=
self
.
assert_retrieve_success
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
assert
response
.
data
==
self
.
serialize_program
(
program
)
...
@@ -138,7 +135,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -138,7 +135,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns a list of all programs. """
""" Verify the endpoint returns a list of all programs. """
expected
=
[
self
.
create_program
()
for
__
in
range
(
3
)]
expected
=
[
self
.
create_program
()
for
__
in
range
(
3
)]
expected
.
reverse
()
expected
.
reverse
()
self
.
assert_list_results
(
self
.
list_path
,
expected
,
1
4
)
self
.
assert_list_results
(
self
.
list_path
,
expected
,
1
2
)
# Verify that repeated list requests use the cache.
# Verify that repeated list requests use the cache.
self
.
assert_list_results
(
self
.
list_path
,
expected
,
2
)
self
.
assert_list_results
(
self
.
list_path
,
expected
,
2
)
...
@@ -148,8 +145,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -148,8 +145,8 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
Verify that the list view returns a simply list of UUIDs when the
Verify that the list view returns a simply list of UUIDs when the
uuids_only query parameter is passed.
uuids_only query parameter is passed.
"""
"""
active
=
ProgramFactory
.
create_batch
(
3
,
partner
=
self
.
partner
)
active
=
ProgramFactory
.
create_batch
(
3
)
retired
=
[
ProgramFactory
(
status
=
ProgramStatus
.
Retired
,
partner
=
self
.
partner
)]
retired
=
[
ProgramFactory
(
status
=
ProgramStatus
.
Retired
)]
programs
=
active
+
retired
programs
=
active
+
retired
querystring
=
{
'uuids_only'
:
1
}
querystring
=
{
'uuids_only'
:
1
}
...
@@ -168,47 +165,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -168,47 +165,47 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_filter_by_type
(
self
):
def
test_filter_by_type
(
self
):
""" Verify that the endpoint filters programs to those of a given type. """
""" Verify that the endpoint filters programs to those of a given type. """
program_type_name
=
'foo'
program_type_name
=
'foo'
program
=
ProgramFactory
(
type__name
=
program_type_name
,
partner
=
self
.
partner
)
program
=
ProgramFactory
(
type__name
=
program_type_name
)
url
=
self
.
list_path
+
'?type='
+
program_type_name
url
=
self
.
list_path
+
'?type='
+
program_type_name
self
.
assert_list_results
(
url
,
[
program
],
10
)
self
.
assert_list_results
(
url
,
[
program
],
8
)
url
=
self
.
list_path
+
'?type=bar'
url
=
self
.
list_path
+
'?type=bar'
self
.
assert_list_results
(
url
,
[],
4
)
self
.
assert_list_results
(
url
,
[],
4
)
def
test_filter_by_types
(
self
):
def
test_filter_by_types
(
self
):
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
""" Verify that the endpoint filters programs to those matching the provided ProgramType slugs. """
expected
=
ProgramFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
expected
=
ProgramFactory
.
create_batch
(
2
)
expected
.
reverse
()
expected
.
reverse
()
type_slugs
=
[
p
.
type
.
slug
for
p
in
expected
]
type_slugs
=
[
p
.
type
.
slug
for
p
in
expected
]
url
=
self
.
list_path
+
'?types='
+
','
.
join
(
type_slugs
)
url
=
self
.
list_path
+
'?types='
+
','
.
join
(
type_slugs
)
# Create a third program, which should be filtered out.
# Create a third program, which should be filtered out.
ProgramFactory
(
partner
=
self
.
partner
)
ProgramFactory
()
self
.
assert_list_results
(
url
,
expected
,
10
)
self
.
assert_list_results
(
url
,
expected
,
8
)
def
test_filter_by_uuids
(
self
):
def
test_filter_by_uuids
(
self
):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """
expected
=
ProgramFactory
.
create_batch
(
2
,
partner
=
self
.
partner
)
expected
=
ProgramFactory
.
create_batch
(
2
)
expected
.
reverse
()
expected
.
reverse
()
uuids
=
[
str
(
p
.
uuid
)
for
p
in
expected
]
uuids
=
[
str
(
p
.
uuid
)
for
p
in
expected
]
url
=
self
.
list_path
+
'?uuids='
+
','
.
join
(
uuids
)
url
=
self
.
list_path
+
'?uuids='
+
','
.
join
(
uuids
)
# Create a third program, which should be filtered out.
# Create a third program, which should be filtered out.
ProgramFactory
(
partner
=
self
.
partner
)
ProgramFactory
()
self
.
assert_list_results
(
url
,
expected
,
10
)
self
.
assert_list_results
(
url
,
expected
,
8
)
@ddt.data
(
@ddt.data
(
(
ProgramStatus
.
Unpublished
,
False
,
6
),
(
ProgramStatus
.
Unpublished
,
False
,
4
),
(
ProgramStatus
.
Active
,
True
,
10
),
(
ProgramStatus
.
Active
,
True
,
8
),
)
)
@ddt.unpack
@ddt.unpack
def
test_filter_by_marketable
(
self
,
status
,
is_marketable
,
expected_query_count
):
def
test_filter_by_marketable
(
self
,
status
,
is_marketable
,
expected_query_count
):
""" Verify the endpoint filters programs to those that are marketable. """
""" Verify the endpoint filters programs to those that are marketable. """
url
=
self
.
list_path
+
'?marketable=1'
url
=
self
.
list_path
+
'?marketable=1'
ProgramFactory
(
marketing_slug
=
''
,
partner
=
self
.
partner
)
ProgramFactory
(
marketing_slug
=
''
)
programs
=
ProgramFactory
.
create_batch
(
3
,
status
=
status
,
partner
=
self
.
partner
)
programs
=
ProgramFactory
.
create_batch
(
3
,
status
=
status
)
programs
.
reverse
()
programs
.
reverse
()
expected
=
programs
if
is_marketable
else
[]
expected
=
programs
if
is_marketable
else
[]
...
@@ -217,11 +214,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -217,11 +214,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_filter_by_status
(
self
):
def
test_filter_by_status
(
self
):
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
""" Verify the endpoint allows programs to filtered by one, or more, statuses. """
active
=
ProgramFactory
(
status
=
ProgramStatus
.
Active
,
partner
=
self
.
partner
)
active
=
ProgramFactory
(
status
=
ProgramStatus
.
Active
)
retired
=
ProgramFactory
(
status
=
ProgramStatus
.
Retired
,
partner
=
self
.
partner
)
retired
=
ProgramFactory
(
status
=
ProgramStatus
.
Retired
)
url
=
self
.
list_path
+
'?status=active'
url
=
self
.
list_path
+
'?status=active'
self
.
assert_list_results
(
url
,
[
active
],
10
)
self
.
assert_list_results
(
url
,
[
active
],
8
)
url
=
self
.
list_path
+
'?status=retired'
url
=
self
.
list_path
+
'?status=retired'
self
.
assert_list_results
(
url
,
[
retired
],
8
)
self
.
assert_list_results
(
url
,
[
retired
],
8
)
...
@@ -231,11 +228,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -231,11 +228,11 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
def
test_filter_by_hidden
(
self
):
def
test_filter_by_hidden
(
self
):
""" Endpoint should filter programs by their hidden attribute value. """
""" Endpoint should filter programs by their hidden attribute value. """
hidden
=
ProgramFactory
(
hidden
=
True
,
partner
=
self
.
partner
)
hidden
=
ProgramFactory
(
hidden
=
True
)
not_hidden
=
ProgramFactory
(
hidden
=
False
,
partner
=
self
.
partner
)
not_hidden
=
ProgramFactory
(
hidden
=
False
)
url
=
self
.
list_path
+
'?hidden=True'
url
=
self
.
list_path
+
'?hidden=True'
self
.
assert_list_results
(
url
,
[
hidden
],
10
)
self
.
assert_list_results
(
url
,
[
hidden
],
8
)
url
=
self
.
list_path
+
'?hidden=False'
url
=
self
.
list_path
+
'?hidden=False'
self
.
assert_list_results
(
url
,
[
not_hidden
],
8
)
self
.
assert_list_results
(
url
,
[
not_hidden
],
8
)
...
@@ -250,7 +247,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
...
@@ -250,7 +247,7 @@ class ProgramViewSetTests(SerializationMixin, APITestCase):
""" Verify the endpoint returns marketing URLs without UTM parameters. """
""" Verify the endpoint returns marketing URLs without UTM parameters. """
url
=
self
.
list_path
+
'?exclude_utm=1'
url
=
self
.
list_path
+
'?exclude_utm=1'
program
=
self
.
create_program
()
program
=
self
.
create_program
()
self
.
assert_list_results
(
url
,
[
program
],
1
4
,
extra_context
=
{
'exclude_utm'
:
1
})
self
.
assert_list_results
(
url
,
[
program
],
1
2
,
extra_context
=
{
'exclude_utm'
:
1
})
def
test_minimal_serializer_use
(
self
):
def
test_minimal_serializer_use
(
self
):
""" Verify that the list view uses the minimal serializer. """
""" Verify that the list view uses the minimal serializer. """
...
...
course_discovery/apps/api/v1/tests/test_views/test_search.py
View file @
e79197f0
...
@@ -3,12 +3,13 @@ import json
...
@@ -3,12 +3,13 @@ import json
import
urllib.parse
import
urllib.parse
import
ddt
import
ddt
from
django.conf
import
settings
from
django.urls
import
reverse
from
django.urls
import
reverse
from
haystack.query
import
SearchQuerySet
from
haystack.query
import
SearchQuerySet
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.serializers
import
(
CourseRunSearchSerializer
,
ProgramSearchSerializer
,
from
course_discovery.apps.api.serializers
import
(
CourseRunSearchSerializer
,
ProgramSearchSerializer
,
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
)
TypeaheadCourseRunSearchSerializer
,
TypeaheadProgramSearchSerializer
)
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
from
course_discovery.apps.api.v1.views.search
import
TypeaheadSearchView
from
course_discovery.apps.api.v1.views.search
import
TypeaheadSearchView
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.factories
import
USER_PASSWORD
,
PartnerFactory
,
UserFactory
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
from
course_discovery.apps.core.tests.mixins
import
ElasticsearchTestMixin
...
@@ -87,8 +88,14 @@ class SynonymTestMixin:
...
@@ -87,8 +88,14 @@ class SynonymTestMixin:
self
.
assertDictEqual
(
response1
,
response2
)
self
.
assertDictEqual
(
response1
,
response2
)
class
DefaultPartnerMixin
:
def
setUp
(
self
):
super
(
DefaultPartnerMixin
,
self
)
.
setUp
()
self
.
partner
=
PartnerFactory
(
pk
=
settings
.
DEFAULT_PARTNER_ID
)
@ddt.ddt
@ddt.ddt
class
CourseRunSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
CourseRunSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
APITestCase
):
APITestCase
):
""" Tests for CourseRunSearchViewSet. """
""" Tests for CourseRunSearchViewSet. """
faceted_path
=
reverse
(
'api:v1:search-course_runs-facets'
)
faceted_path
=
reverse
(
'api:v1:search-course_runs-facets'
)
...
@@ -264,7 +271,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
...
@@ -264,7 +271,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
)
)
self
.
reindex_courses
(
program
)
self
.
reindex_courses
(
program
)
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
4
):
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -288,7 +295,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
...
@@ -288,7 +295,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
ProgramFactory
(
courses
=
[
course_run
.
course
],
status
=
program_status
)
ProgramFactory
(
courses
=
[
course_run
.
course
],
status
=
program_status
)
self
.
reindex_courses
(
active_program
)
self
.
reindex_courses
(
active_program
)
with
self
.
assertNumQueries
(
6
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
response
=
self
.
get_response
(
'software'
,
faceted
=
False
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
...
@@ -306,7 +313,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
...
@@ -306,7 +313,7 @@ class CourseRunSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
@ddt.ddt
@ddt.ddt
class
AggregateSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
AggregateSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'api:v1:search-all-facets'
)
path
=
reverse
(
'api:v1:search-all-facets'
)
...
@@ -431,7 +438,7 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
...
@@ -431,7 +438,7 @@ class AggregateSearchViewSetTests(SerializationMixin, LoginMixin, ElasticsearchT
assert
expected
==
actual
assert
expected
==
actual
class
TypeaheadSearchViewTests
(
TypeaheadSerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
class
TypeaheadSearchViewTests
(
DefaultPartnerMixin
,
TypeaheadSerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'api:v1:search-typeahead'
)
path
=
reverse
(
'api:v1:search-typeahead'
)
...
@@ -613,3 +620,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
...
@@ -613,3 +620,23 @@ class TypeaheadSearchViewTests(TypeaheadSerializationMixin, LoginMixin, Elastics
self
.
serialize_program
(
harvard_program
)]
self
.
serialize_program
(
harvard_program
)]
}
}
self
.
assertDictEqual
(
response
.
data
,
expected
)
self
.
assertDictEqual
(
response
.
data
,
expected
)
def
test_typeahead_partner_filter
(
self
):
""" Ensure that a partner param limits results to that partner. """
course_runs
=
[]
programs
=
[]
for
partner
in
[
'edx'
,
'other'
]:
title
=
'Belongs to partner '
+
partner
partner
=
PartnerFactory
(
short_code
=
partner
)
course_runs
.
append
(
CourseRunFactory
(
title
=
title
,
course
=
CourseFactory
(
partner
=
partner
)))
programs
.
append
(
ProgramFactory
(
title
=
title
,
partner
=
partner
,
status
=
ProgramStatus
.
Active
))
response
=
self
.
get_response
({
'q'
:
'partner'
},
'edx'
)
self
.
assertEqual
(
response
.
status_code
,
200
)
edx_course_run
=
course_runs
[
0
]
edx_program
=
programs
[
0
]
self
.
assertDictEqual
(
response
.
data
,
{
'course_runs'
:
[
self
.
serialize_course_run
(
edx_course_run
)],
'programs'
:
[
self
.
serialize_program
(
edx_program
)]})
course_discovery/apps/api/v1/views/__init__.py
View file @
e79197f0
...
@@ -46,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
...
@@ -46,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset
=
queryset
.
select_related
(
*
_select_related_fields
[
'course'
])
queryset
=
queryset
.
select_related
(
*
_select_related_fields
[
'course'
])
queryset
=
queryset
.
prefetch_related
(
*
_prefetch_fields
[
'course'
])
queryset
=
queryset
.
prefetch_related
(
*
_prefetch_fields
[
'course'
])
return
queryset
return
queryset
class
PartnerMixin
:
def
get_partner
(
self
):
""" Return the partner for the short_code passed in or the default partner """
partner_code
=
self
.
request
.
query_params
.
get
(
'partner'
)
if
partner_code
:
try
:
partner
=
Partner
.
objects
.
get
(
short_code
=
partner_code
)
except
Partner
.
DoesNotExist
:
raise
InvalidPartnerError
(
'Unknown Partner: {}'
.
format
(
partner_code
))
else
:
partner
=
Partner
.
objects
.
get
(
id
=
settings
.
DEFAULT_PARTNER_ID
)
return
partner
course_discovery/apps/api/v1/views/course_runs.py
View file @
e79197f0
...
@@ -7,14 +7,14 @@ from rest_framework.response import Response
...
@@ -7,14 +7,14 @@ from rest_framework.response import Response
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api.pagination
import
ProxiedPagination
from
course_discovery.apps.api.pagination
import
ProxiedPagination
from
course_discovery.apps.api.v1.views
import
get_query_param
from
course_discovery.apps.api.v1.views
import
PartnerMixin
,
get_query_param
from
course_discovery.apps.core.utils
import
SearchQuerySetWrapper
from
course_discovery.apps.core.utils
import
SearchQuerySetWrapper
from
course_discovery.apps.course_metadata.constants
import
COURSE_RUN_ID_REGEX
from
course_discovery.apps.course_metadata.constants
import
COURSE_RUN_ID_REGEX
from
course_discovery.apps.course_metadata.models
import
CourseRun
from
course_discovery.apps.course_metadata.models
import
CourseRun
# pylint: disable=no-member
# pylint: disable=no-member
class
CourseRunViewSet
(
viewsets
.
ModelViewSet
):
class
CourseRunViewSet
(
PartnerMixin
,
viewsets
.
ModelViewSet
):
""" CourseRun resource. """
""" CourseRun resource. """
filter_backends
=
(
DjangoFilterBackend
,
OrderingFilter
)
filter_backends
=
(
DjangoFilterBackend
,
OrderingFilter
)
filter_class
=
filters
.
CourseRunFilter
filter_class
=
filters
.
CourseRunFilter
...
@@ -41,7 +41,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
...
@@ -41,7 +41,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
multiple: false
multiple: false
"""
"""
q
=
self
.
request
.
query_params
.
get
(
'q'
)
q
=
self
.
request
.
query_params
.
get
(
'q'
)
partner
=
self
.
request
.
site
.
partner
partner
=
self
.
get_partner
()
if
q
:
if
q
:
qs
=
SearchQuerySetWrapper
(
CourseRun
.
search
(
q
)
.
filter
(
partner
=
partner
.
short_code
))
qs
=
SearchQuerySetWrapper
(
CourseRun
.
search
(
q
)
.
filter
(
partner
=
partner
.
short_code
))
...
@@ -78,6 +78,12 @@ class CourseRunViewSet(viewsets.ModelViewSet):
...
@@ -78,6 +78,12 @@ class CourseRunViewSet(viewsets.ModelViewSet):
type: string
type: string
paramType: query
paramType: query
multiple: false
multiple: false
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: hidden
- name: hidden
description: Filter based on wether the course run is hidden from search.
description: Filter based on wether the course run is hidden from search.
required: false
required: false
...
@@ -158,7 +164,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
...
@@ -158,7 +164,7 @@ class CourseRunViewSet(viewsets.ModelViewSet):
"""
"""
query
=
request
.
GET
.
get
(
'query'
)
query
=
request
.
GET
.
get
(
'query'
)
course_run_ids
=
request
.
GET
.
get
(
'course_run_ids'
)
course_run_ids
=
request
.
GET
.
get
(
'course_run_ids'
)
partner
=
self
.
request
.
site
.
partner
partner
=
self
.
get_partner
()
if
query
and
course_run_ids
:
if
query
and
course_run_ids
:
course_run_ids
=
course_run_ids
.
split
(
','
)
course_run_ids
=
course_run_ids
.
split
(
','
)
...
...
course_discovery/apps/api/v1/views/courses.py
View file @
e79197f0
...
@@ -18,6 +18,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -18,6 +18,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
filter_class
=
filters
.
CourseFilter
filter_class
=
filters
.
CourseFilter
lookup_field
=
'key'
lookup_field
=
'key'
lookup_value_regex
=
COURSE_ID_REGEX
lookup_value_regex
=
COURSE_ID_REGEX
queryset
=
Course
.
objects
.
all
()
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
serializer_class
=
serializers
.
CourseWithProgramsSerializer
serializer_class
=
serializers
.
CourseWithProgramsSerializer
...
@@ -26,17 +27,16 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -26,17 +27,16 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
pagination_class
=
ProxiedPagination
pagination_class
=
ProxiedPagination
def
get_queryset
(
self
):
def
get_queryset
(
self
):
partner
=
self
.
request
.
site
.
partner
q
=
self
.
request
.
query_params
.
get
(
'q'
)
q
=
self
.
request
.
query_params
.
get
(
'q'
)
if
q
:
if
q
:
queryset
=
Course
.
search
(
q
)
queryset
=
Course
.
search
(
q
)
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
queryset
,
partner
=
partner
)
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
queryset
)
else
:
else
:
if
get_query_param
(
self
.
request
,
'include_hidden_course_runs'
):
if
get_query_param
(
self
.
request
,
'include_hidden_course_runs'
):
course_runs
=
CourseRun
.
objects
.
filter
(
course__partner
=
partner
)
course_runs
=
CourseRun
.
objects
.
all
(
)
else
:
else
:
course_runs
=
CourseRun
.
objects
.
filter
(
course__partner
=
partner
)
.
exclude
(
hidden
=
True
)
course_runs
=
CourseRun
.
objects
.
exclude
(
hidden
=
True
)
if
get_query_param
(
self
.
request
,
'marketable_course_runs_only'
):
if
get_query_param
(
self
.
request
,
'marketable_course_runs_only'
):
course_runs
=
course_runs
.
marketable
()
.
active
()
course_runs
=
course_runs
.
marketable
()
.
active
()
...
@@ -49,8 +49,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -49,8 +49,7 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
self
.
get_serializer_class
()
.
prefetch_queryset
(
queryset
=
self
.
queryset
,
queryset
=
self
.
queryset
,
course_runs
=
course_runs
,
course_runs
=
course_runs
partner
=
partner
)
)
return
queryset
.
order_by
(
Lower
(
'key'
))
return
queryset
.
order_by
(
Lower
(
'key'
))
...
...
course_discovery/apps/api/v1/views/organizations.py
View file @
e79197f0
...
@@ -15,16 +15,13 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
...
@@ -15,16 +15,13 @@ class OrganizationViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field
=
'uuid'
lookup_field
=
'uuid'
lookup_value_regex
=
'[0-9a-f-]+'
lookup_value_regex
=
'[0-9a-f-]+'
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
queryset
=
serializers
.
OrganizationSerializer
.
prefetch_queryset
()
serializer_class
=
serializers
.
OrganizationSerializer
serializer_class
=
serializers
.
OrganizationSerializer
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# Explicitly support PageNumberPagination and LimitOffsetPagination. Future
# versions of this API should only support the system default, PageNumberPagination.
# versions of this API should only support the system default, PageNumberPagination.
pagination_class
=
ProxiedPagination
pagination_class
=
ProxiedPagination
def
get_queryset
(
self
):
partner
=
self
.
request
.
site
.
partner
return
serializers
.
OrganizationSerializer
.
prefetch_queryset
(
partner
=
partner
)
def
list
(
self
,
request
,
*
args
,
**
kwargs
):
def
list
(
self
,
request
,
*
args
,
**
kwargs
):
""" Retrieve a list of all organizations. """
""" Retrieve a list of all organizations. """
return
super
(
OrganizationViewSet
,
self
)
.
list
(
request
,
*
args
,
**
kwargs
)
return
super
(
OrganizationViewSet
,
self
)
.
list
(
request
,
*
args
,
**
kwargs
)
...
...
course_discovery/apps/api/v1/views/people.py
View file @
e79197f0
...
@@ -7,6 +7,7 @@ from rest_framework.response import Response
...
@@ -7,6 +7,7 @@ from rest_framework.response import Response
from
course_discovery.apps.api
import
serializers
from
course_discovery.apps.api
import
serializers
from
course_discovery.apps.api.pagination
import
PageNumberPagination
from
course_discovery.apps.api.pagination
import
PageNumberPagination
from
course_discovery.apps.api.v1.views
import
PartnerMixin
from
course_discovery.apps.course_metadata.exceptions
import
MarketingSiteAPIClientException
,
PersonToMarketingException
from
course_discovery.apps.course_metadata.exceptions
import
MarketingSiteAPIClientException
,
PersonToMarketingException
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
from
course_discovery.apps.course_metadata.people
import
MarketingSitePeople
...
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
...
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=no-member
# pylint: disable=no-member
class
PersonViewSet
(
viewsets
.
ModelViewSet
):
class
PersonViewSet
(
PartnerMixin
,
viewsets
.
ModelViewSet
):
""" PersonSerializer resource. """
""" PersonSerializer resource. """
lookup_field
=
'uuid'
lookup_field
=
'uuid'
...
@@ -29,7 +30,7 @@ class PersonViewSet(viewsets.ModelViewSet):
...
@@ -29,7 +30,7 @@ class PersonViewSet(viewsets.ModelViewSet):
""" Create a new person. """
""" Create a new person. """
person_data
=
request
.
data
person_data
=
request
.
data
partner
=
request
.
site
.
partner
partner
=
self
.
get_partner
()
person_data
[
'partner'
]
=
partner
.
id
person_data
[
'partner'
]
=
partner
.
id
serializer
=
self
.
get_serializer
(
data
=
person_data
)
serializer
=
self
.
get_serializer
(
data
=
person_data
)
serializer
.
is_valid
(
raise_exception
=
True
)
serializer
.
is_valid
(
raise_exception
=
True
)
...
...
course_discovery/apps/api/v1/views/programs.py
View file @
e79197f0
...
@@ -32,8 +32,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
...
@@ -32,8 +32,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
def
get_queryset
(
self
):
def
get_queryset
(
self
):
# This method prevents prefetches on the program queryset from "stacking,"
# This method prevents prefetches on the program queryset from "stacking,"
# which happens when the queryset is stored in a class property.
# which happens when the queryset is stored in a class property.
partner
=
self
.
request
.
site
.
partner
return
self
.
get_serializer_class
()
.
prefetch_queryset
()
return
self
.
get_serializer_class
()
.
prefetch_queryset
(
partner
)
def
get_serializer_context
(
self
,
*
args
,
**
kwargs
):
def
get_serializer_context
(
self
,
*
args
,
**
kwargs
):
context
=
super
()
.
get_serializer_context
(
*
args
,
**
kwargs
)
context
=
super
()
.
get_serializer_context
(
*
args
,
**
kwargs
)
...
@@ -90,7 +89,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
...
@@ -90,7 +89,7 @@ class ProgramViewSet(CacheResponseMixin, viewsets.ReadOnlyModelViewSet):
if
get_query_param
(
self
.
request
,
'uuids_only'
):
if
get_query_param
(
self
.
request
,
'uuids_only'
):
# DRF serializers don't have good support for simple, flat
# DRF serializers don't have good support for simple, flat
# representations like the one we want here.
# representations like the one we want here.
queryset
=
self
.
filter_queryset
(
Program
.
objects
.
filter
(
partner
=
self
.
request
.
site
.
partner
))
queryset
=
self
.
filter_queryset
(
Program
.
objects
.
all
(
))
uuids
=
queryset
.
values_list
(
'uuid'
,
flat
=
True
)
uuids
=
queryset
.
values_list
(
'uuid'
,
flat
=
True
)
return
Response
(
uuids
)
return
Response
(
uuids
)
...
...
course_discovery/apps/api/v1/views/search.py
View file @
e79197f0
...
@@ -12,6 +12,7 @@ from rest_framework.response import Response
...
@@ -12,6 +12,7 @@ from rest_framework.response import Response
from
rest_framework.views
import
APIView
from
rest_framework.views
import
APIView
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api
import
filters
,
serializers
from
course_discovery.apps.api.v1.views
import
PartnerMixin
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
ProgramStatus
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Program
from
course_discovery.apps.course_metadata.models
import
Course
,
CourseRun
,
Program
...
@@ -118,7 +119,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
...
@@ -118,7 +119,7 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class
=
serializers
.
AggregateSearchSerializer
serializer_class
=
serializers
.
AggregateSearchSerializer
class
TypeaheadSearchView
(
APIView
):
class
TypeaheadSearchView
(
PartnerMixin
,
APIView
):
""" Typeahead for courses and programs. """
""" Typeahead for courses and programs. """
RESULT_COUNT
=
3
RESULT_COUNT
=
3
permission_classes
=
(
IsAuthenticated
,)
permission_classes
=
(
IsAuthenticated
,)
...
@@ -180,7 +181,7 @@ class TypeaheadSearchView(APIView):
...
@@ -180,7 +181,7 @@ class TypeaheadSearchView(APIView):
type: string
type: string
"""
"""
query
=
request
.
query_params
.
get
(
'q'
)
query
=
request
.
query_params
.
get
(
'q'
)
partner
=
request
.
site
.
partner
partner
=
self
.
get_partner
()
if
not
query
:
if
not
query
:
raise
ValidationError
(
"The 'q' querystring parameter is required for searching."
)
raise
ValidationError
(
"The 'q' querystring parameter is required for searching."
)
course_runs
,
programs
=
self
.
get_results
(
query
,
partner
)
course_runs
,
programs
=
self
.
get_results
(
query
,
partner
)
...
...
course_discovery/apps/core/tests/test_views.py
View file @
e79197f0
...
@@ -10,7 +10,6 @@ from django.urls import reverse
...
@@ -10,7 +10,6 @@ from django.urls import reverse
from
django.utils.encoding
import
force_text
from
django.utils.encoding
import
force_text
from
course_discovery.apps.core.constants
import
Status
from
course_discovery.apps.core.constants
import
Status
from
course_discovery.apps.core.views
import
get_database_status
User
=
get_user_model
()
User
=
get_user_model
()
...
@@ -18,24 +17,13 @@ User = get_user_model()
...
@@ -18,24 +17,13 @@ User = get_user_model()
class
HealthTests
(
TestCase
):
class
HealthTests
(
TestCase
):
"""Tests of the health endpoint."""
"""Tests of the health endpoint."""
def
test_getting_database_ok_status
(
self
):
"""Method should return the OK status."""
status
=
get_database_status
()
self
.
assertEqual
(
status
,
Status
.
OK
)
def
test_getting_database_unavailable_status
(
self
):
"""Method should return the unavailable status when a DatabaseError occurs."""
with
mock
.
patch
(
'django.db.backends.base.base.BaseDatabaseWrapper.cursor'
,
side_effect
=
DatabaseError
):
status
=
get_database_status
()
self
.
assertEqual
(
status
,
Status
.
UNAVAILABLE
)
def
test_all_services_available
(
self
):
def
test_all_services_available
(
self
):
"""Test that the endpoint reports when all services are healthy."""
"""Test that the endpoint reports when all services are healthy."""
self
.
_assert_health
(
200
,
Status
.
OK
,
Status
.
OK
)
self
.
_assert_health
(
200
,
Status
.
OK
,
Status
.
OK
)
def
test_database_outage
(
self
):
def
test_database_outage
(
self
):
"""Test that the endpoint reports when the database is unavailable."""
"""Test that the endpoint reports when the database is unavailable."""
with
mock
.
patch
(
'
course_discovery.apps.core.views.get_database_status'
,
return_value
=
Status
.
UNAVAILABLE
):
with
mock
.
patch
(
'
django.db.backends.base.base.BaseDatabaseWrapper.cursor'
,
side_effect
=
DatabaseError
):
self
.
_assert_health
(
503
,
Status
.
UNAVAILABLE
,
Status
.
UNAVAILABLE
)
self
.
_assert_health
(
503
,
Status
.
UNAVAILABLE
,
Status
.
UNAVAILABLE
)
def
_assert_health
(
self
,
status_code
,
overall_status
,
database_status
):
def
_assert_health
(
self
,
status_code
,
overall_status
,
database_status
):
...
...
course_discovery/apps/core/views.py
View file @
e79197f0
...
@@ -15,18 +15,6 @@ logger = logging.getLogger(__name__)
...
@@ -15,18 +15,6 @@ logger = logging.getLogger(__name__)
User
=
get_user_model
()
User
=
get_user_model
()
def
get_database_status
():
"""Run a database query to see if the database is responsive."""
try
:
cursor
=
connection
.
cursor
()
cursor
.
execute
(
"SELECT 1"
)
cursor
.
fetchone
()
cursor
.
close
()
return
Status
.
OK
except
DatabaseError
:
return
Status
.
UNAVAILABLE
@transaction.non_atomic_requests
@transaction.non_atomic_requests
def
health
(
_
):
def
health
(
_
):
"""Allows a load balancer to verify this service is up.
"""Allows a load balancer to verify this service is up.
...
@@ -44,7 +32,15 @@ def health(_):
...
@@ -44,7 +32,15 @@ def health(_):
>>> response.content
>>> response.content
'{"overall_status": "OK", "detailed_status": {"database_status": "OK"}}'
'{"overall_status": "OK", "detailed_status": {"database_status": "OK"}}'
"""
"""
database_status
=
get_database_status
()
try
:
cursor
=
connection
.
cursor
()
cursor
.
execute
(
"SELECT 1"
)
cursor
.
fetchone
()
cursor
.
close
()
database_status
=
Status
.
OK
except
DatabaseError
:
database_status
=
Status
.
UNAVAILABLE
overall_status
=
Status
.
OK
if
(
database_status
==
Status
.
OK
)
else
Status
.
UNAVAILABLE
overall_status
=
Status
.
OK
if
(
database_status
==
Status
.
OK
)
else
Status
.
UNAVAILABLE
...
...
course_discovery/apps/edx_catalog_extensions/api/v1/tests/test_views.py
View file @
e79197f0
...
@@ -2,17 +2,17 @@ import datetime
...
@@ -2,17 +2,17 @@ import datetime
import
urllib.parse
import
urllib.parse
from
django.urls
import
reverse
from
django.urls
import
reverse
from
rest_framework.test
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.mixins
import
APITestCase
from
course_discovery.apps.api.v1.tests.test_views.test_search
import
(
from
course_discovery.apps.api.v1.tests.test_views.test_search
import
(
ElasticsearchTestMixin
,
LoginMixin
,
SerializationMixin
,
SynonymTestMixin
DefaultPartnerMixin
,
ElasticsearchTestMixin
,
LoginMixin
,
SerializationMixin
,
SynonymTestMixin
)
)
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.choices
import
CourseRunStatus
,
ProgramStatus
from
course_discovery.apps.course_metadata.tests.factories
import
CourseFactory
,
CourseRunFactory
,
ProgramFactory
from
course_discovery.apps.course_metadata.tests.factories
import
CourseFactory
,
CourseRunFactory
,
ProgramFactory
from
course_discovery.apps.edx_catalog_extensions.api.serializers
import
DistinctCountsAggregateFacetSearchSerializer
from
course_discovery.apps.edx_catalog_extensions.api.serializers
import
DistinctCountsAggregateFacetSearchSerializer
class
DistinctCountsAggregateSearchViewSetTests
(
SerializationMixin
,
LoginMixin
,
class
DistinctCountsAggregateSearchViewSetTests
(
DefaultPartnerMixin
,
SerializationMixin
,
LoginMixin
,
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
ElasticsearchTestMixin
,
SynonymTestMixin
,
APITestCase
):
path
=
reverse
(
'extensions:api:v1:search-all-facets'
)
path
=
reverse
(
'extensions:api:v1:search-all-facets'
)
...
...
course_discovery/settings/base.py
View file @
e79197f0
...
@@ -79,7 +79,6 @@ MIDDLEWARE_CLASSES = (
...
@@ -79,7 +79,6 @@ MIDDLEWARE_CLASSES = (
'django.contrib.auth.middleware.AuthenticationMiddleware'
,
'django.contrib.auth.middleware.AuthenticationMiddleware'
,
'django.contrib.auth.middleware.SessionAuthenticationMiddleware'
,
'django.contrib.auth.middleware.SessionAuthenticationMiddleware'
,
'django.contrib.messages.middleware.MessageMiddleware'
,
'django.contrib.messages.middleware.MessageMiddleware'
,
'django.contrib.sites.middleware.CurrentSiteMiddleware'
,
'django.middleware.clickjacking.XFrameOptionsMiddleware'
,
'django.middleware.clickjacking.XFrameOptionsMiddleware'
,
'social_django.middleware.SocialAuthExceptionMiddleware'
,
'social_django.middleware.SocialAuthExceptionMiddleware'
,
'waffle.middleware.WaffleMiddleware'
,
'waffle.middleware.WaffleMiddleware'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment