Commit 903d013d by Ivan Ivic

[SOL-1975] Sort CourseRunViewSet queryset by start date

parent 5eecb677
...@@ -385,14 +385,16 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer): ...@@ -385,14 +385,16 @@ class CourseRunWithProgramsSerializer(CourseRunSerializer):
programs = serializers.SerializerMethodField() programs = serializers.SerializerMethodField()
def get_programs(self, obj): def get_programs(self, obj):
programs = []
# Filter out non-deleted programs which this course_run is part of the program course_run exclusion # Filter out non-deleted programs which this course_run is part of the program course_run exclusion
programs = [program for program in obj.programs.all() if obj.programs:
if (self.context.get('include_deleted_programs') or programs = [program for program in obj.programs.all()
program.status != ProgramStatus.Deleted) and if (self.context.get('include_deleted_programs') or
obj.id not in (run.id for run in program.excluded_course_runs.all())] program.status != ProgramStatus.Deleted) and
# If flag is not set, remove programs from list that are unpublished obj.id not in (run.id for run in program.excluded_course_runs.all())]
if not self.context.get('include_unpublished_programs'): # If flag is not set, remove programs from list that are unpublished
programs = [program for program in programs if program.status != ProgramStatus.Unpublished] if not self.context.get('include_unpublished_programs'):
programs = [program for program in programs if program.status != ProgramStatus.Unpublished]
return NestedProgramSerializer(programs, many=True).data return NestedProgramSerializer(programs, many=True).data
......
...@@ -110,6 +110,19 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC ...@@ -110,6 +110,19 @@ class CourseRunViewSetTests(SerializationMixin, ElasticsearchTestMixin, APITestC
self.serialize_course_run(CourseRun.objects.all().order_by(Lower('key')), many=True) self.serialize_course_run(CourseRun.objects.all().order_by(Lower('key')), many=True)
) )
def test_list_sorted_by_course_start_date(self):
""" Verify the endpoint returns a list of all catalogs sorted by course start date. """
url = '{root}?ordering=start'.format(root=reverse('api:v1:course_run-list'))
with self.assertNumQueries(11):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertListEqual(
response.data['results'],
self.serialize_course_run(CourseRun.objects.all().order_by('start'), many=True)
)
def test_list_query(self): def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """ """ Verify the endpoint returns a filtered list of courses """
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner) course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)
......
...@@ -19,7 +19,7 @@ from haystack.query import SQ ...@@ -19,7 +19,7 @@ from haystack.query import SQ
from rest_framework import status, viewsets from rest_framework import status, viewsets
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.exceptions import PermissionDenied, ParseError from rest_framework.exceptions import PermissionDenied, ParseError
from rest_framework.filters import DjangoFilterBackend from rest_framework.filters import DjangoFilterBackend, OrderingFilter
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
...@@ -289,12 +289,13 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -289,12 +289,13 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """ """ CourseRun resource. """
filter_backends = (DjangoFilterBackend,) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter filter_class = filters.CourseRunFilter
lookup_field = 'key' lookup_field = 'key'
lookup_value_regex = COURSE_RUN_ID_REGEX lookup_value_regex = COURSE_RUN_ID_REGEX
queryset = CourseRun.objects.all().order_by(Lower('key')) ordering_fields = ('start',)
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
queryset = CourseRun.objects.all().order_by(Lower('key'))
serializer_class = serializers.CourseRunWithProgramsSerializer serializer_class = serializers.CourseRunWithProgramsSerializer
def _get_partner(self): def _get_partner(self):
......
...@@ -67,8 +67,13 @@ class SearchQuerySetWrapper(object): ...@@ -67,8 +67,13 @@ class SearchQuerySetWrapper(object):
def __init__(self, qs): def __init__(self, qs):
self.qs = qs self.qs = qs
def count(self): def __getattr__(self, item):
return self.qs.count() try:
super().__getattr__(item)
except AttributeError:
# If the attribute is not found on this class,
# proxy the request to the SearchQuerySet.
return getattr(self.qs, item)
def __iter__(self): def __iter__(self):
for result in self.qs: for result in self.qs:
......
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('course_metadata', '0033_courserun_mobile_available'),
]
operations = [
migrations.AlterField(
model_name='courserun',
name='start',
field=models.DateTimeField(blank=True, null=True, db_index=True),
),
]
...@@ -331,7 +331,7 @@ class CourseRun(TimeStampedModel): ...@@ -331,7 +331,7 @@ class CourseRun(TimeStampedModel):
max_length=255, default=None, null=True, blank=True, max_length=255, default=None, null=True, blank=True,
help_text=_( help_text=_(
"Title specific for this run of a course. Leave this value blank to default to the parent course's title.")) "Title specific for this run of a course. Leave this value blank to default to the parent course's title."))
start = models.DateTimeField(null=True, blank=True) start = models.DateTimeField(null=True, blank=True, db_index=True)
end = models.DateTimeField(null=True, blank=True, db_index=True) end = models.DateTimeField(null=True, blank=True, db_index=True)
enrollment_start = models.DateTimeField(null=True, blank=True) enrollment_start = models.DateTimeField(null=True, blank=True)
enrollment_end = models.DateTimeField(null=True, blank=True, db_index=True) enrollment_end = models.DateTimeField(null=True, blank=True, db_index=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment