Commit 8ce36d2b by Tom Christie

SearchFilter and tests

parent 773a92ea
...@@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend): ...@@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
search_param = 'search'
def construct_search(self, field_name): def construct_search(self, field_name):
if field_name.startswith('^'): if field_name.startswith('^'):
return "%s__istartswith" % field_name[1:] return "%s__istartswith" % field_name[1:]
...@@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend): ...@@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
if not search_fields: if not search_fields:
return None return None
search_terms = request.QUERY_PARAMS.get(self.search_param)
orm_lookups = [self.construct_search(str(search_field)) orm_lookups = [self.construct_search(str(search_field))
for search_field in self.search_fields] for search_field in search_fields]
for bit in self.query.split():
for bit in search_terms.split():
or_queries = [models.Q(**{orm_lookup: bit}) or_queries = [models.Q(**{orm_lookup: bit})
for orm_lookup in orm_lookups] for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries)) queryset = queryset.filter(reduce(operator.or_, or_queries))
return queryset return queryset
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db import models
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel from rest_framework.tests.models import BasicModel
factory = RequestFactory() factory = RequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters: if django_filters:
# Basic filter on a list view. # Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
...@@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): ...@@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data) self.assertEqual(response.data, valid_item_data)
class SearchFilterModel(models.Model):
title = models.CharField(max_length=20)
text = models.CharField(max_length=100)
class SearchFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
#
# z abc
# zz bcd
# zzz cde
# ...
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save()
def test_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 1, 'title': u'z', 'text': u'abc'},
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=zzz')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 3, 'title': u'zzz', 'text': u'cde'}
]
)
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment