Commit c5069d83 by Matthew Piatetsky

Filter typeahead results by partner

parent 58b5c4c2
import logging import logging
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from course_discovery.apps.api import serializers from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.utils import cast2int from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model() User = get_user_model()
...@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset): ...@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course']) queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course']) queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
from django.conf import settings
from django.db.models.functions import Lower from django.db.models.functions import Lower
from rest_framework import viewsets, status from rest_framework import viewsets, status
from rest_framework.decorators import list_route from rest_framework.decorators import list_route
...@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated ...@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError from course_discovery.apps.api.v1.views import get_query_param, PartnerMixin
from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.core.models import Partner
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member # pylint: disable=no-member
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): class CourseRunViewSet(PartnerMixin, viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """ """ CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter filter_class = filters.CourseRunFilter
...@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
queryset = CourseRun.objects.all().order_by(Lower('key')) queryset = CourseRun.objects.all().order_by(Lower('key'))
serializer_class = serializers.CourseRunWithProgramsSerializer serializer_class = serializers.CourseRunWithProgramsSerializer
def _get_partner(self):
""" Return the partner for the code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
def get_queryset(self): def get_queryset(self):
""" List one course run """ List one course run
--- ---
...@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
multiple: false multiple: false
""" """
q = self.request.query_params.get('q', None) q = self.request.query_params.get('q', None)
partner = self._get_partner() partner = self.get_partner()
if q: if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code)) qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
...@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
""" """
query = request.GET.get('query') query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids') course_run_ids = request.GET.get('course_run_ids')
partner = self._get_partner() partner = self.get_partner()
if query and course_run_ids: if query and course_run_ids:
course_run_ids = course_run_ids.split(',') course_run_ids = course_run_ids.split(',')
......
...@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery ...@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery
from haystack.query import SearchQuerySet from haystack.query import SearchQuerySet
from rest_framework import status from rest_framework import status
from rest_framework.decorators import list_route from rest_framework.decorators import list_route
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.filters import OrderingFilter from rest_framework.filters import OrderingFilter
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
...@@ -13,6 +13,7 @@ from rest_framework.views import APIView ...@@ -13,6 +13,7 @@ from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import PageNumberPagination from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
...@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet): ...@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(APIView): class TypeaheadSearchView(PartnerMixin, APIView):
""" Typeahead for courses and programs. """ """ Typeahead for courses and programs. """
RESULT_COUNT = 3 RESULT_COUNT = 3
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
def get_results(self, query): def get_results(self, query, partner):
sqs = SearchQuerySet() sqs = SearchQuerySet()
clean_query = sqs.query.clean(query) clean_query = sqs.query.clean(query)
...@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView): ...@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView):
SQ(course_key=clean_query) | SQ(course_key=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query) SQ(authoring_organizations_autocomplete=clean_query)
) )
course_runs = course_runs.filter(published=True).exclude(hidden=True) course_runs = course_runs.filter(published=True).exclude(hidden=True).filter(partner=partner.short_code)
course_runs = course_runs[:self.RESULT_COUNT] course_runs = course_runs[:self.RESULT_COUNT]
programs = sqs.models(Program).filter( programs = sqs.models(Program).filter(
SQ(title_autocomplete=clean_query) | SQ(title_autocomplete=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query) SQ(authoring_organizations_autocomplete=clean_query)
) )
programs = programs.filter(status=ProgramStatus.Active) programs = programs.filter(status=ProgramStatus.Active).filter(partner=partner.short_code)
programs = programs[:self.RESULT_COUNT] programs = programs[:self.RESULT_COUNT]
return course_runs, programs return course_runs, programs
...@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView): ...@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView):
paramType: query paramType: query
required: true required: true
type: string type: string
- name: partner
description: "Partner short code"
paramType: query
required: false
type: string
""" """
query = request.query_params.get('q') query = request.query_params.get('q')
partner = self.get_partner()
if not query: if not query:
raise ParseError("The 'q' querystring parameter is required for searching.") raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query) course_runs, programs = self.get_results(query, partner)
data = {'course_runs': course_runs, 'programs': programs} data = {'course_runs': course_runs, 'programs': programs}
serializer = serializers.TypeaheadSearchSerializer(data) serializer = serializers.TypeaheadSearchSerializer(data)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment