Commit db6a4123 by Clinton Blackburn Committed by GitHub

Added active and marketable filters to CourseRun and Program models (#287)

API consumers can now request CourseRuns that are active/marketable and Programs that are marketable.

ECOM-5416
parent c7b5e126
import logging
import django_filters import django_filters
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db.models import QuerySet from django.db.models import QuerySet
...@@ -10,6 +12,7 @@ from rest_framework.exceptions import PermissionDenied, NotFound ...@@ -10,6 +12,7 @@ from rest_framework.exceptions import PermissionDenied, NotFound
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
...@@ -66,6 +69,24 @@ class CharListFilter(django_filters.CharFilter): ...@@ -66,6 +69,24 @@ class CharListFilter(django_filters.CharFilter):
return super(CharListFilter, self).filter(qs, value) return super(CharListFilter, self).filter(qs, value)
class FilterSetMixin:
def _apply_filter(self, name, queryset, value):
try:
if int(value):
queryset = getattr(queryset, name)()
except ValueError:
logger.exception('The "%s" filter requires an integer value of either 0 or 1. %s is invalid', name, value)
raise
return queryset
def filter_active(self, queryset, value):
return self._apply_filter('active', queryset, value)
def filter_marketable(self, queryset, value):
return self._apply_filter('marketable', queryset, value)
class CourseFilter(django_filters.FilterSet): class CourseFilter(django_filters.FilterSet):
keys = CharListFilter(name='key', lookup_type='in') keys = CharListFilter(name='key', lookup_type='in')
...@@ -74,7 +95,9 @@ class CourseFilter(django_filters.FilterSet): ...@@ -74,7 +95,9 @@ class CourseFilter(django_filters.FilterSet):
fields = ['keys'] fields = ['keys']
class CourseRunFilter(django_filters.FilterSet): class CourseRunFilter(FilterSetMixin, django_filters.FilterSet):
active = django_filters.MethodFilter()
marketable = django_filters.MethodFilter()
keys = CharListFilter(name='key', lookup_type='in') keys = CharListFilter(name='key', lookup_type='in')
@property @property
...@@ -91,7 +114,8 @@ class CourseRunFilter(django_filters.FilterSet): ...@@ -91,7 +114,8 @@ class CourseRunFilter(django_filters.FilterSet):
fields = ['keys'] fields = ['keys']
class ProgramFilter(django_filters.FilterSet): class ProgramFilter(FilterSetMixin, django_filters.FilterSet):
marketable = django_filters.MethodFilter()
type = django_filters.CharFilter(name='type__name', lookup_expr='iexact') type = django_filters.CharFilter(name='type__name', lookup_expr='iexact')
uuids = CharListFilter(name='uuid', lookup_type='in') uuids = CharListFilter(name='uuid', lookup_type='in')
......
...@@ -281,7 +281,7 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer): ...@@ -281,7 +281,7 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer):
class Meta(CourseRunSerializer.Meta): class Meta(CourseRunSerializer.Meta):
model = CourseRun model = CourseRun
fields = CourseRunSerializer.Meta.fields + ('programs', ) fields = CourseRunSerializer.Meta.fields + ('programs',)
class ContainedCourseRunsSerializer(serializers.Serializer): class ContainedCourseRunsSerializer(serializers.Serializer):
...@@ -323,12 +323,15 @@ class CourseWithProgramsSerializer(CourseSerializer): ...@@ -323,12 +323,15 @@ class CourseWithProgramsSerializer(CourseSerializer):
class Meta(CourseSerializer.Meta): class Meta(CourseSerializer.Meta):
model = Course model = Course
fields = CourseSerializer.Meta.fields + ('programs', ) fields = CourseSerializer.Meta.fields + ('programs',)
class CourseSerializerExcludingClosedRuns(CourseSerializer): class CourseSerializerExcludingClosedRuns(CourseSerializer):
"""A ``CourseSerializer`` which only includes active course runs, as determined by ``CourseQuerySet``.""" """A ``CourseSerializer`` which only includes active course runs, as determined by ``CourseQuerySet``."""
course_runs = CourseRunSerializer(many=True, source='active_course_runs') course_runs = serializers.SerializerMethodField()
def get_course_runs(self, course):
return CourseRunSerializer(course.course_runs.active().marketable(), many=True, context=self.context).data
class ContainedCoursesSerializer(serializers.Serializer): class ContainedCoursesSerializer(serializers.Serializer):
......
# pylint: disable=no-member # pylint: disable=no-member
import datetime
import urllib import urllib
import ddt import ddt
import pytz
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APITestCase, APIRequestFactory from rest_framework.test import APITestCase, APIRequestFactory
...@@ -19,9 +21,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -19,9 +21,9 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
super(CourseRunViewSetTests, self).setUp() super(CourseRunViewSetTests, self).setUp()
self.user = UserFactory(is_staff=True, is_superuser=True) self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user) self.client.force_authenticate(self.user)
self.default_partner = PartnerFactory() self.partner = PartnerFactory()
self.course_run = CourseRunFactory(course__partner=self.default_partner) self.course_run = CourseRunFactory(course__partner=self.partner)
self.course_run_2 = CourseRunFactory(course__partner=self.default_partner) self.course_run_2 = CourseRunFactory(course__partner=self.partner)
self.refresh_index() self.refresh_index()
self.request = APIRequestFactory().get('/') self.request = APIRequestFactory().get('/')
self.request.user = self.user self.request.user = self.user
...@@ -50,7 +52,7 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -50,7 +52,7 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
def test_list_query(self): def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """ """ Verify the endpoint returns a filtered list of courses """
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.default_partner) course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)
CourseRunFactory(title='non-matching name') CourseRunFactory(title='non-matching name')
query = 'title:Some random title' query = 'title:Some random title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query) url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
...@@ -70,15 +72,53 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase): ...@@ -70,15 +72,53 @@ class CourseRunViewSetTests(ElasticsearchTestMixin, APITestCase):
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_list_key_filter(self): def assert_list_results(self, url, expected):
expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(response.data['results'], self.serialize_course_run(expected, many=True))
def test_filter_by_keys(self):
""" Verify the endpoint returns a list of course runs filtered by the specified keys. """ """ Verify the endpoint returns a list of course runs filtered by the specified keys. """
course_runs = CourseRunFactory.create_batch(3, course__partner=self.default_partner) CourseRun.objects.all().delete()
course_runs = sorted(course_runs, key=lambda course: course.key.lower()) expected = CourseRunFactory.create_batch(3, course__partner=self.partner)
keys = ','.join([course.key for course in course_runs]) keys = ','.join([course.key for course in expected])
url = '{root}?keys={keys}'.format(root=reverse('api:v1:course_run-list'), keys=keys) url = '{root}?keys={keys}'.format(root=reverse('api:v1:course_run-list'), keys=keys)
self.assert_list_results(url, expected)
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course_run(course_runs, many=True)) def test_filter_by_marketable(self):
""" Verify the endpoint filters course runs to those that are marketable. """
CourseRun.objects.all().delete()
expected = CourseRunFactory.create_batch(3, course__partner=self.partner)
CourseRunFactory.create_batch(3, slug=None, course__partner=self.partner)
CourseRunFactory.create_batch(3, slug='', course__partner=self.partner)
url = reverse('api:v1:course_run-list') + '?marketable=1'
self.assert_list_results(url, expected)
def test_filter_by_active(self):
""" Verify the endpoint filters course runs to those that are active. """
CourseRun.objects.all().delete()
# Create course with end date in future and enrollment_end in past.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=1)
CourseRunFactory(end=end, enrollment_end=enrollment_end, course__partner=self.partner)
# Create course with end date in past and no enrollment_end.
end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2)
CourseRunFactory(end=end, enrollment_end=None, course__partner=self.partner)
# Create course with end date in future and enrollment_end in future.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=1)
active_enrollment_end = CourseRunFactory(end=end, enrollment_end=enrollment_end, course__partner=self.partner)
# Create course with end date in future and no enrollment_end.
active_no_enrollment_end = CourseRunFactory(end=end, enrollment_end=None, course__partner=self.partner)
expected = [active_enrollment_end, active_no_enrollment_end]
url = reverse('api:v1:course_run-list') + '?active=1'
self.assert_list_results(url, expected)
def test_contains_single_course_run(self): def test_contains_single_course_run(self):
""" Verify that a single course_run is contained in a query """ """ Verify that a single course_run is contained in a query """
......
...@@ -4,7 +4,7 @@ from rest_framework.test import APITestCase, APIRequestFactory ...@@ -4,7 +4,7 @@ from rest_framework.test import APITestCase, APIRequestFactory
from course_discovery.apps.api.serializers import ProgramSerializer from course_discovery.apps.api.serializers import ProgramSerializer
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.course_metadata.models import Program from course_discovery.apps.course_metadata.models import Program
from course_discovery.apps.course_metadata.tests.factories import ProgramFactory, ProgramTypeFactory from course_discovery.apps.course_metadata.tests.factories import ProgramFactory
class ProgramViewSetTests(APITestCase): class ProgramViewSetTests(APITestCase):
...@@ -16,7 +16,6 @@ class ProgramViewSetTests(APITestCase): ...@@ -16,7 +16,6 @@ class ProgramViewSetTests(APITestCase):
self.client.login(username=self.user.username, password=USER_PASSWORD) self.client.login(username=self.user.username, password=USER_PASSWORD)
self.request = APIRequestFactory().get('/') self.request = APIRequestFactory().get('/')
self.request.user = self.user self.request.user = self.user
self.program = ProgramFactory()
def test_authentication(self): def test_authentication(self):
""" Verify the endpoint requires the user to be authenticated. """ """ Verify the endpoint requires the user to be authenticated. """
...@@ -29,51 +28,66 @@ class ProgramViewSetTests(APITestCase): ...@@ -29,51 +28,66 @@ class ProgramViewSetTests(APITestCase):
def test_get(self): def test_get(self):
""" Verify the endpoint returns the details for a single program. """ """ Verify the endpoint returns the details for a single program. """
url = reverse('api:v1:program-detail', kwargs={'uuid': self.program.uuid}) program = ProgramFactory()
url = reverse('api:v1:program-detail', kwargs={'uuid': program.uuid})
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, ProgramSerializer(self.program, context={'request': self.request}).data) self.assertEqual(response.data, ProgramSerializer(program, context={'request': self.request}).data)
def test_list(self): def assert_list_results(self, url, expected):
""" Verify the endpoint returns a list of all programs. """ """
ProgramFactory.create_batch(3) Asserts the results serialized/returned at the URL matches those that are expected.
Args:
response = self.client.get(self.list_path) url (str): URL from which data should be retrieved
self.assertEqual(response.status_code, 200) expected (list[Program]): Expected programs
Notes:
The API usually returns items in reverse order of creation (e.g. newest first). You may need to reverse
the values of `expected` if you encounter issues. This method will NOT do that reversal for you.
Returns:
None
"""
response = self.client.get(url)
self.assertEqual( self.assertEqual(
response.data['results'], response.data['results'],
ProgramSerializer(Program.objects.all(), many=True, context={'request': self.request}).data ProgramSerializer(expected, many=True, context={'request': self.request}).data
) )
def test_list(self):
""" Verify the endpoint returns a list of all programs. """
expected = ProgramFactory.create_batch(3)
expected.reverse()
self.assert_list_results(self.list_path, expected)
def test_filter_by_type(self): def test_filter_by_type(self):
""" Verify that the endpoint filters programs to those of a given type. """ """ Verify that the endpoint filters programs to those of a given type. """
url = self.list_path + '?type=' program_type_name = 'foo'
program = ProgramFactory(type__name=program_type_name)
url = self.list_path + '?type=' + program_type_name
self.assert_list_results(url, [program])
self.program.type = ProgramTypeFactory(name='Foo') url = self.list_path + '?type=bar'
self.program.save() # pylint: disable=no-member self.assert_list_results(url, [])
response = self.client.get(url + 'foo')
self.assertEqual(
response.data['results'][0],
ProgramSerializer(Program.objects.get(), context={'request': self.request}).data
)
response = self.client.get(url + 'bar')
self.assertEqual(response.data['results'], [])
def test_filter_by_uuids(self): def test_filter_by_uuids(self):
""" Verify that the endpoint filters programs to those matching the provided UUIDs. """ """ Verify that the endpoint filters programs to those matching the provided UUIDs. """
url = self.list_path + '?uuids=' expected = ProgramFactory.create_batch(2)
expected.reverse()
programs = [ProgramFactory(), self.program] uuids = [str(p.uuid) for p in expected]
uuids = [str(p.uuid) for p in programs] url = self.list_path + '?uuids=' + ','.join(uuids)
# Create a third program, which should be filtered out. # Create a third program, which should be filtered out.
ProgramFactory() ProgramFactory()
response = self.client.get(url + ','.join(uuids)) self.assert_list_results(url, expected)
self.assertEqual(
response.data['results'], def test_filter_by_marketable(self):
ProgramSerializer(programs, many=True, context={'request': self.request}).data """ Verify the endpoint filters programs to those that are marketable. """
) url = self.list_path + '?marketable=1'
ProgramFactory(marketing_slug='')
expected = ProgramFactory.create_batch(3)
expected.reverse()
self.assertEqual(list(Program.objects.marketable()), expected)
self.assert_list_results(url, expected)
...@@ -157,9 +157,7 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -157,9 +157,7 @@ class CatalogViewSet(viewsets.ModelViewSet):
course_runs = [] course_runs = []
for course in courses: for course in courses:
active_course_runs = course.active_course_runs course_runs += list(course.course_runs.active().marketable())
for acr in active_course_runs:
course_runs.append(acr)
serializer = serializers.FlattenedCourseRunWithCourseSerializer( serializer = serializers.FlattenedCourseRunWithCourseSerializer(
course_runs, many=True, context={'request': request} course_runs, many=True, context={'request': request}
...@@ -274,6 +272,20 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -274,6 +272,20 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
type: string type: string
paramType: query paramType: query
multiple: false multiple: false
- name: active
description: Retrieve active course runs. A course is considered active if its end date has not passed,
and it is open for enrollment.
required: false
type: integer
paramType: query
multiple: false
- name: marketable
description: Retrieve marketable course runs. A course run is considered marketable if it has a
marketing slug.
required: false
type: integer
paramType: query
multiple: false
""" """
return super(CourseRunViewSet, self).list(request, *args, **kwargs) return super(CourseRunViewSet, self).list(request, *args, **kwargs)
...@@ -336,6 +348,25 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -336,6 +348,25 @@ class ProgramViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend,)
filter_class = filters.ProgramFilter filter_class = filters.ProgramFilter
def list(self, request, *args, **kwargs):
""" List all programs.
---
parameters:
- name: partner
description: Filter by partner
required: false
type: string
paramType: query
multiple: false
- name: marketable
description: Retrieve marketable programs. A program is considered marketable if it has a marketing slug.
required: false
type: integer
paramType: query
multiple: false
"""
return super(ProgramViewSet, self).list(request, *args, **kwargs)
class ManagementViewSet(viewsets.ViewSet): class ManagementViewSet(viewsets.ViewSet):
permission_classes = (IsSuperuser,) permission_classes = (IsSuperuser,)
......
...@@ -68,7 +68,6 @@ class ProgramAdmin(admin.ModelAdmin): ...@@ -68,7 +68,6 @@ class ProgramAdmin(admin.ModelAdmin):
list_filter = ('partner', 'type',) list_filter = ('partner', 'type',)
ordering = ('uuid', 'title', 'status') ordering = ('uuid', 'title', 'status')
readonly_fields = ('uuid', 'custom_course_runs_display', 'excluded_course_runs',) readonly_fields = ('uuid', 'custom_course_runs_display', 'excluded_course_runs',)
search_fields = ('uuid', 'title', 'marketing_slug') search_fields = ('uuid', 'title', 'marketing_slug')
filter_horizontal = ('job_outlook_items', 'expected_learning_items',) filter_horizontal = ('job_outlook_items', 'expected_learning_items',)
......
...@@ -18,7 +18,7 @@ from stdimage.models import StdImageField ...@@ -18,7 +18,7 @@ from stdimage.models import StdImageField
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
from course_discovery.apps.core.models import Currency, Partner from course_discovery.apps.core.models import Currency, Partner
from course_discovery.apps.course_metadata.query import CourseQuerySet from course_discovery.apps.course_metadata.query import CourseQuerySet, CourseRunQuerySet, ProgramQuerySet
from course_discovery.apps.course_metadata.utils import UploadToFieldNamePath from course_discovery.apps.course_metadata.utils import UploadToFieldNamePath
from course_discovery.apps.course_metadata.utils import clean_query from course_discovery.apps.course_metadata.utils import clean_query
from course_discovery.apps.ietf_language_tags.models import LanguageTag from course_discovery.apps.ietf_language_tags.models import LanguageTag
...@@ -358,6 +358,7 @@ class CourseRun(TimeStampedModel): ...@@ -358,6 +358,7 @@ class CourseRun(TimeStampedModel):
slug = models.CharField(max_length=255, blank=True, null=True, db_index=True) slug = models.CharField(max_length=255, blank=True, null=True, db_index=True)
history = HistoricalRecords() history = HistoricalRecords()
objects = CourseRunQuerySet.as_manager()
@property @property
def marketing_url(self): def marketing_url(self):
...@@ -602,7 +603,8 @@ class Program(TimeStampedModel): ...@@ -602,7 +603,8 @@ class Program(TimeStampedModel):
'large': (1440, 480), 'large': (1440, 480),
'medium': (726, 242), 'medium': (726, 242),
'small': (435, 145), 'small': (435, 145),
'x-small': (348, 116)} 'x-small': (348, 116),
}
) )
banner_image_url = models.URLField(null=True, blank=True, help_text=_('Image used atop detail pages')) banner_image_url = models.URLField(null=True, blank=True, help_text=_('Image used atop detail pages'))
card_image_url = models.URLField(null=True, blank=True, help_text=_('Image used for discovery cards')) card_image_url = models.URLField(null=True, blank=True, help_text=_('Image used for discovery cards'))
...@@ -621,6 +623,8 @@ class Program(TimeStampedModel): ...@@ -621,6 +623,8 @@ class Program(TimeStampedModel):
blank=True, null=True blank=True, null=True
) )
objects = ProgramQuerySet.as_manager()
def __str__(self): def __str__(self):
return self.title return self.title
......
...@@ -17,3 +17,47 @@ class CourseQuerySet(models.QuerySet): ...@@ -17,3 +17,47 @@ class CourseQuerySet(models.QuerySet):
Q(course_runs__enrollment_end__isnull=True) Q(course_runs__enrollment_end__isnull=True)
) )
) )
class CourseRunQuerySet(models.QuerySet):
def active(self):
""" Returns CourseRuns that have not yet ended and meet the following enrollment criteria:
- Open for enrollment
- OR will be open for enrollment in the future
- OR have no specified enrollment close date (e.g. self-paced courses)
Returns:
QuerySet
"""
now = datetime.datetime.now(pytz.UTC)
return self.filter(
Q(end__gt=now) &
(
Q(enrollment_end__gt=now) |
Q(enrollment_end__isnull=True)
)
)
def marketable(self):
""" Returns CourseRuns that can be marketed to learners.
A CourseRun is considered marketable if it has a defined slug.
Returns:
QuerySet
"""
return self.exclude(slug__isnull=True).exclude(slug='')
class ProgramQuerySet(models.QuerySet):
def marketable(self):
""" Returns Programs that can be marketed to learners.
A Program is considered marketable if it has a defined marketing slug.
Returns:
QuerySet
"""
return self.exclude(marketing_slug__isnull=True).exclude(marketing_slug='')
import datetime
import itertools import itertools
from decimal import Decimal from decimal import Decimal
import ddt import ddt
import pytz
from dateutil.parser import parse from dateutil.parser import parse
from django.conf import settings from django.conf import settings
from django.db import IntegrityError from django.db import IntegrityError
...@@ -37,31 +35,6 @@ class CourseTests(TestCase): ...@@ -37,31 +35,6 @@ class CourseTests(TestCase):
""" Verify casting an instance to a string returns a string containing the key and title. """ """ Verify casting an instance to a string returns a string containing the key and title. """
self.assertEqual(str(self.course), '{key}: {title}'.format(key=self.course.key, title=self.course.title)) self.assertEqual(str(self.course), '{key}: {title}'.format(key=self.course.key, title=self.course.title))
def test_active_course_runs(self):
""" Verify the property returns only course runs currently open for enrollment or opening in the future. """
self.assertListEqual(list(self.course.active_course_runs), [])
# Create course with end date in future and enrollment_end in past.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=1)
factories.CourseRunFactory(course=self.course, end=end, enrollment_end=enrollment_end)
# Create course with end date in past and no enrollment_end.
end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2)
factories.CourseRunFactory(course=self.course, end=end, enrollment_end=None)
self.assertListEqual(list(self.course.active_course_runs), [])
# Create course with end date in future and enrollment_end in future.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=1)
active_enrollment_end = factories.CourseRunFactory(course=self.course, end=end, enrollment_end=enrollment_end)
# Create course with end date in future and no enrollment_end.
active_no_enrollment_end = factories.CourseRunFactory(course=self.course, end=end, enrollment_end=None)
self.assertEqual(set(self.course.active_course_runs), {active_enrollment_end, active_no_enrollment_end})
def test_search(self): def test_search(self):
""" Verify the method returns a filtered queryset of courses. """ """ Verify the method returns a filtered queryset of courses. """
title = 'Some random title' title = 'Some random title'
......
import datetime import datetime
import ddt
import pytz import pytz
from django.test import TestCase from django.test import TestCase
from course_discovery.apps.course_metadata.models import Course from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, ProgramFactory
class CourseQuerySetTests(TestCase): class CourseQuerySetTests(TestCase):
def test_active(self): def test_active(self):
""" Verify the method filters the Courses to those with active course runs. """ """ Verify the method filters the Courses to those with active course runs. """
now = datetime.datetime.now(pytz.UTC) now = datetime.datetime.now(pytz.UTC)
...@@ -33,3 +33,52 @@ class CourseQuerySetTests(TestCase): ...@@ -33,3 +33,52 @@ class CourseQuerySetTests(TestCase):
CourseRunFactory(enrollment_end=None, end=inactive_course_end) CourseRunFactory(enrollment_end=None, end=inactive_course_end)
self.assertEqual(set(Course.objects.active()), {active_course, course_without_end}) self.assertEqual(set(Course.objects.active()), {active_course, course_without_end})
@ddt.ddt
class CourseRunQuerySetTests(TestCase):
def test_active(self):
""" Verify the method returns only course runs currently open for enrollment or opening in the future. """
# Create course with end date in future and enrollment_end in past.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=1)
CourseRunFactory(end=end, enrollment_end=enrollment_end)
# Create course with end date in past and no enrollment_end.
end = datetime.datetime.now(pytz.UTC) - datetime.timedelta(days=2)
CourseRunFactory(end=end, enrollment_end=None)
self.assertEqual(CourseRun.objects.active().count(), 0)
# Create course with end date in future and enrollment_end in future.
end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=2)
enrollment_end = datetime.datetime.now(pytz.UTC) + datetime.timedelta(days=1)
active_enrollment_end = CourseRunFactory(end=end, enrollment_end=enrollment_end)
# Create course with end date in future and no enrollment_end.
active_no_enrollment_end = CourseRunFactory(end=end, enrollment_end=None)
self.assertEqual(set(CourseRun.objects.active()), {active_enrollment_end, active_no_enrollment_end})
def test_marketable(self):
""" Verify the method filters CourseRuns to those with slugs. """
course_run = CourseRunFactory()
self.assertEqual(list(CourseRun.objects.marketable()), [course_run])
@ddt.data(None, '')
def test_marketable_exclusions(self, slug):
""" Verify the method excludes CourseRuns without a slug. """
CourseRunFactory(slug=slug)
self.assertEqual(CourseRun.objects.marketable().count(), 0)
class ProgramQuerySetTests(TestCase):
def test_marketable(self):
""" Verify the method filters Programs to those with marketing slugs. """
program = ProgramFactory()
self.assertEqual(list(Program.objects.marketable()), [program])
def test_marketable_exclusions(self):
""" Verify the method excludes Programs without a marketing slug. """
ProgramFactory(marketing_slug='')
self.assertEqual(Program.objects.marketable().count(), 0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment