Commit de69a28b by Tom Christie

Test and fix for #814.

parent 9d59e55c
...@@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend): ...@@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend):
""" """
filter_class = getattr(view, 'filter_class', None) filter_class = getattr(view, 'filter_class', None)
filter_fields = getattr(view, 'filter_fields', None) filter_fields = getattr(view, 'filter_fields', None)
view_model = getattr(view, 'model', None) model_cls = getattr(view, 'model', None)
queryset = getattr(view, 'queryset', None)
if model_cls is None and queryset is not None:
model_cls = queryset.model
if filter_class: if filter_class:
filter_model = filter_class.Meta.model filter_model = filter_class.Meta.model
assert issubclass(filter_model, view_model), \ assert issubclass(filter_model, model_cls), \
'FilterSet model %s does not match view model %s' % \ 'FilterSet model %s does not match view model %s' % \
(filter_model, view_model) (filter_model, model_cls)
return filter_class return filter_class
if filter_fields: if filter_fields:
assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \
'on a view which does not have a .model or .queryset attribute.'
class AutoFilterSet(self.default_filter_set): class AutoFilterSet(self.default_filter_set):
class Meta: class Meta:
model = view_model model = model_cls
fields = filter_fields fields = filter_fields
return AutoFilterSet return AutoFilterSet
......
...@@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse ...@@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel from rest_framework.tests.models import FilterableItem, BasicModel
...@@ -52,6 +52,17 @@ if django_filters: ...@@ -52,6 +52,17 @@ if django_filters:
filter_class = SeveralFieldsFilter filter_class = SeveralFieldsFilter
filter_backend = filters.DjangoFilterBackend filter_backend = filters.DjangoFilterBackend
# Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
model = FilterableItem
class FilterFieldsQuerysetView(generics.ListCreateAPIView):
queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_fields = ['decimal', 'date']
filter_backend = filters.DjangoFilterBackend
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'),
...@@ -115,6 +126,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -115,6 +126,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed') @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_filter_with_queryset(self):
"""
Regression test for #814.
"""
view = FilterFieldsQuerysetView.as_view()
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self): def test_get_filtered_class_root_view(self):
""" """
GET requests to filtered ListCreateAPIView that have a filter_class set GET requests to filtered ListCreateAPIView that have a filter_class set
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment