Commit 83cef72b by Clinton Blackburn

Restored course search functionality

ECOM-4120
parent f2ec8b14
......@@ -4,8 +4,8 @@ from rest_framework.test import APITestCase
from course_discovery.apps.api.v1.tests.test_views.mixins import SerializationMixin
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD
from course_discovery.apps.course_metadata.tests.factories import CourseFactory
from course_discovery.apps.course_metadata.models import Course
from course_discovery.apps.course_metadata.tests.factories import CourseFactory
class CourseViewSetTests(SerializationMixin, APITestCase):
......@@ -24,7 +24,7 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
self.assertEqual(response.data, self.serialize_course(self.course))
def test_list(self):
""" Verify the endpoint returns a list of all catalogs. """
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
response = self.client.get(url)
......@@ -33,3 +33,14 @@ class CourseViewSetTests(SerializationMixin, APITestCase):
response.data['results'],
self.serialize_course(Course.objects.all().order_by(Lower('key')), many=True)
)
def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """
title = 'Some random course'
courses = CourseFactory.create_batch(3, title=title)
courses = sorted(courses, key=lambda course: course.key.lower())
query = 'title:' + title
url = '{root}?q={query}'.format(root=reverse('api:v1:course-list'), query=query)
response = self.client.get(url)
self.assertListEqual(response.data['results'], self.serialize_course(courses, many=True))
......@@ -103,10 +103,15 @@ class CourseViewSet(viewsets.ReadOnlyModelViewSet):
""" Course resource. """
lookup_field = 'key'
lookup_value_regex = COURSE_ID_REGEX
queryset = Course.objects.all().order_by(Lower('key'))
queryset = Course.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = CourseSerializer
def get_queryset(self):
q = self.request.query_params.get('q', None)
queryset = Course.search(q) if q else super(CourseViewSet, self).get_queryset()
return queryset.order_by(Lower('key'))
# The boilerplate methods are required to be recognized by swagger
def list(self, request, *args, **kwargs):
""" List all courses. """
......
......@@ -29,9 +29,7 @@ class Catalog(ModelPermissionsMixin, TimeStampedModel):
Returns:
QuerySet
"""
results = self._get_query_results()
ids = [result.pk for result in results]
return Course.objects.filter(pk__in=ids)
return Course.search(self.query)
@property
def courses_count(self):
......
......@@ -5,6 +5,7 @@ import pytz
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django_extensions.db.models import TimeStampedModel
from haystack.query import SearchQuerySet
from simple_history.models import HistoricalRecords
from sortedm2m.fields import SortedManyToManyField
......@@ -152,6 +153,22 @@ class Course(TimeStampedModel):
"""
return self.course_runs.filter(enrollment_end__gt=datetime.datetime.now(pytz.UTC))
@classmethod
def search(cls, query):
""" Queries the search index.
Args:
query (str) -- Elasticsearch querystring (e.g. `title:intro*`)
Returns:
QuerySet
"""
# NOTE (CCB): Ensure the query is lowercase, since that is how we index our data.
query = query.lower()
results = SearchQuerySet().models(cls).raw_search(query)
ids = [result.pk for result in results]
return cls.objects.filter(pk__in=ids)
def __str__(self):
return '{key}: {title}'.format(key=self.key, title=self.title)
......
......@@ -5,7 +5,7 @@ import pytz
from django.test import TestCase
from course_discovery.apps.course_metadata.models import (
AbstractNamedModel, AbstractMediaModel, AbstractValueModel, CourseOrganization
AbstractNamedModel, AbstractMediaModel, AbstractValueModel, CourseOrganization, Course
)
from course_discovery.apps.course_metadata.tests import factories
......@@ -58,6 +58,15 @@ class CourseTests(TestCase):
active = factories.CourseRunFactory(course=self.course, enrollment_end=enrollment_end)
self.assertListEqual(list(self.course.active_course_runs), [active])
def test_search(self):
""" Verify the method returns a filtered queryset of courses. """
title = 'Some random course'
courses = factories.CourseFactory.create_batch(3, title=title)
courses = sorted(courses, key=lambda course: course.key)
query = 'title:' + title
actual = list(Course.search(query).order_by('key'))
self.assertEqual(actual, courses)
@ddt.ddt
class CourseRunTests(TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment