Commit c5069d83 by Matthew Piatetsky

Filter typeahead results by partner

parent 58b5c4c2
import logging
from django.conf import settings
from django.contrib.auth import get_user_model
from course_discovery.apps.api import serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.utils import cast2int
from course_discovery.apps.core.models import Partner
logger = logging.getLogger(__name__)
User = get_user_model()
......@@ -43,3 +46,18 @@ def prefetch_related_objects_for_courses(queryset):
queryset = queryset.select_related(*_select_related_fields['course'])
queryset = queryset.prefetch_related(*_prefetch_fields['course'])
return queryset
class PartnerMixin:
def get_partner(self):
""" Return the partner for the short_code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner: {}'.format(partner_code))
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
from django.conf import settings
from django.db.models.functions import Lower
from rest_framework import viewsets, status
from rest_framework.decorators import list_route
......@@ -7,16 +6,14 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.exceptions import InvalidPartnerError
from course_discovery.apps.api.v1.views import get_query_param
from course_discovery.apps.core.models import Partner
from course_discovery.apps.api.v1.views import get_query_param, PartnerMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import CourseRun
# pylint: disable=no-member
class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
class CourseRunViewSet(PartnerMixin, viewsets.ReadOnlyModelViewSet):
""" CourseRun resource. """
filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_class = filters.CourseRunFilter
......@@ -27,19 +24,6 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
queryset = CourseRun.objects.all().order_by(Lower('key'))
serializer_class = serializers.CourseRunWithProgramsSerializer
def _get_partner(self):
""" Return the partner for the code passed in or the default partner """
partner_code = self.request.query_params.get('partner')
if partner_code:
try:
partner = Partner.objects.get(short_code=partner_code)
except Partner.DoesNotExist:
raise InvalidPartnerError('Unknown Partner')
else:
partner = Partner.objects.get(id=settings.DEFAULT_PARTNER_ID)
return partner
def get_queryset(self):
""" List one course run
---
......@@ -52,7 +36,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
multiple: false
"""
q = self.request.query_params.get('q', None)
partner = self._get_partner()
partner = self.get_partner()
if q:
qs = SearchQuerySetWrapper(CourseRun.search(q).filter(partner=partner.short_code))
......@@ -167,7 +151,7 @@ class CourseRunViewSet(viewsets.ReadOnlyModelViewSet):
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids')
partner = self._get_partner()
partner = self.get_partner()
if query and course_run_ids:
course_run_ids = course_run_ids.split(',')
......
......@@ -5,7 +5,7 @@ from haystack.inputs import AutoQuery
from haystack.query import SearchQuerySet
from rest_framework import status
from rest_framework.decorators import list_route
from rest_framework.exceptions import ParseError
from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.filters import OrderingFilter
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
......@@ -13,6 +13,7 @@ from rest_framework.views import APIView
from course_discovery.apps.api import filters, serializers
from course_discovery.apps.api.pagination import PageNumberPagination
from course_discovery.apps.api.v1.views import PartnerMixin
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.models import Course, CourseRun, Program
......@@ -123,12 +124,12 @@ class AggregateSearchViewSet(BaseHaystackViewSet):
serializer_class = serializers.AggregateSearchSerializer
class TypeaheadSearchView(APIView):
class TypeaheadSearchView(PartnerMixin, APIView):
""" Typeahead for courses and programs. """
RESULT_COUNT = 3
permission_classes = (IsAuthenticated,)
def get_results(self, query):
def get_results(self, query, partner):
sqs = SearchQuerySet()
clean_query = sqs.query.clean(query)
......@@ -137,14 +138,14 @@ class TypeaheadSearchView(APIView):
SQ(course_key=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query)
)
course_runs = course_runs.filter(published=True).exclude(hidden=True)
course_runs = course_runs.filter(published=True).exclude(hidden=True).filter(partner=partner.short_code)
course_runs = course_runs[:self.RESULT_COUNT]
programs = sqs.models(Program).filter(
SQ(title_autocomplete=clean_query) |
SQ(authoring_organizations_autocomplete=clean_query)
)
programs = programs.filter(status=ProgramStatus.Active)
programs = programs.filter(status=ProgramStatus.Active).filter(partner=partner.short_code)
programs = programs[:self.RESULT_COUNT]
return course_runs, programs
......@@ -165,11 +166,17 @@ class TypeaheadSearchView(APIView):
paramType: query
required: true
type: string
- name: partner
description: "Partner short code"
paramType: query
required: false
type: string
"""
query = request.query_params.get('q')
partner = self.get_partner()
if not query:
raise ParseError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query)
raise ValidationError("The 'q' querystring parameter is required for searching.")
course_runs, programs = self.get_results(query, partner)
data = {'course_runs': course_runs, 'programs': programs}
serializer = serializers.TypeaheadSearchSerializer(data)
return Response(serializer.data, status=status.HTTP_200_OK)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment