Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

""" 

Provides generic filtering backends that can be used to filter the results 

returned by list views. 

""" 

from __future__ import unicode_literals 

from django.db import models 

from rest_framework.compat import django_filters, six 

from functools import reduce 

import operator 

 

# Fall back to None when django-filter is unavailable (`django_filters` is
# falsy in that case). A conditional expression is used instead of the
# `x and y or z` idiom, which would also yield None if `y` were ever falsy.
FilterSet = django_filters.FilterSet if django_filters else None

 

 

class BaseFilterBackend(object):
    """
    Common ancestor for every filter backend class.

    Subclasses are expected to supply their own `.filter_queryset()`.
    """

    def filter_queryset(self, request, queryset, view):
        """
        Produce the filtered form of `queryset` for the given request and view.

        This base implementation always raises `NotImplementedError`.
        """
        message = ".filter_queryset() must be overridden."
        raise NotImplementedError(message)

 

 

class DjangoFilterBackend(BaseFilterBackend):
    """
    Filter backend that delegates the filtering work to django-filter.
    """
    # Base class used when a FilterSet is generated from `filter_fields`.
    default_filter_set = FilterSet

    def __init__(self):
        # Fail early if django-filter could not be imported.
        assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'

    def get_filter_class(self, view, queryset=None):
        """
        Work out which django-filters `FilterSet` class applies to `view`.

        A `filter_class` attribute on the view takes precedence; otherwise a
        `FilterSet` is generated from the view's `filter_fields` attribute.
        Returns `None` when the view sets neither.
        """
        explicit_class = getattr(view, 'filter_class', None)
        declared_fields = getattr(view, 'filter_fields', None)

        if explicit_class:
            declared_model = explicit_class.Meta.model

            # The FilterSet's model must be the queryset's model or a
            # subclass of it.
            assert issubclass(declared_model, queryset.model), (
                'FilterSet model %s does not match queryset model %s'
                % (declared_model, queryset.model)
            )

            return explicit_class

        if not declared_fields:
            return None

        # Build a FilterSet on the fly for the queryset's model, restricted
        # to the fields listed on the view.
        class AutoFilterSet(self.default_filter_set):
            class Meta:
                model = queryset.model
                fields = declared_fields

        return AutoFilterSet

    def filter_queryset(self, request, queryset, view):
        """
        Apply the view's `FilterSet` (if any) to `queryset`.

        The queryset is returned untouched when no filter class applies.
        """
        filterset_class = self.get_filter_class(view, queryset)

        if not filterset_class:
            return queryset

        filterset = filterset_class(request.QUERY_PARAMS, queryset=queryset)
        return filterset.qs

 

 

class SearchFilter(BaseFilterBackend):
    """
    Filter backend that narrows a queryset using the view's `search_fields`
    and the terms supplied in the request's query parameters.
    """
    # Name of the URL query parameter that carries the search terms.
    search_param = 'search'

    def get_search_terms(self, request):
        """
        Extract the search terms from the request's query parameters.

        The value of the `search_param` parameter is split on commas and/or
        whitespace; a missing parameter yields an empty list.
        """
        raw_value = request.QUERY_PARAMS.get(self.search_param, '')
        return raw_value.replace(',', ' ').split()

    def construct_search(self, field_name):
        """
        Translate a search field declaration into an ORM lookup name.

        A leading '^', '=' or '@' selects an `istartswith`, `iexact` or
        `search` lookup respectively (the marker itself is dropped); any
        other name gets an `icontains` lookup.
        """
        prefixed_lookups = (
            ('^', "%s__istartswith"),
            ('=', "%s__iexact"),
            ('@', "%s__search"),
        )
        for marker, template in prefixed_lookups:
            if field_name.startswith(marker):
                return template % field_name[1:]
        return "%s__icontains" % field_name

    def filter_queryset(self, request, queryset, view):
        """
        Filter `queryset` so that every search term matches at least one of
        the view's `search_fields`.

        The queryset is returned unchanged when the view has no
        `search_fields`.
        """
        fields = getattr(view, 'search_fields', None)
        if not fields:
            return queryset

        lookups = [self.construct_search(str(name)) for name in fields]

        for term in self.get_search_terms(request):
            # OR the lookups together for this term; chaining `.filter()`
            # across terms ANDs the terms with each other.
            per_field = [models.Q(**{lookup: term}) for lookup in lookups]
            any_field_matches = reduce(operator.or_, per_field)
            queryset = queryset.filter(any_field_matches)

        return queryset

 

 

class OrderingFilter(BaseFilterBackend):
    """
    A filter backend that orders the queryset by the fields named in a
    query parameter, falling back to the view's `ordering` attribute.
    """
    ordering_param = 'ordering'  # The URL query parameter used for the ordering.

    def get_ordering(self, request):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.

        Returns a list of whitespace-stripped terms, or `None` when the
        parameter is missing or empty.
        """
        params = request.QUERY_PARAMS.get(self.ordering_param)
        if params:
            return [param.strip() for param in params.split(',')]
        return None

    def get_default_ordering(self, view):
        """
        Return the view's `ordering` attribute, or `None` if it has none.

        A single string is wrapped in a one-element tuple so that the
        result can always be unpacked into `order_by()`.
        """
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, six.string_types):
            return (ordering,)
        return ordering

    def remove_invalid_fields(self, queryset, ordering):
        """
        Return the terms of `ordering` that name a field on the queryset's
        model, preserving their order.

        A term may carry a single leading '-' to request descending order.
        Only that one dash is ignored when validating: a term such as
        '--name' is dropped rather than passed through to `order_by()`.
        """
        valid_names = set(field.name for field in queryset.model._meta.fields)

        def is_valid(term):
            # Strip at most one '-' prefix; `lstrip('-')` would accept any
            # number of dashes and let malformed terms through.
            name = term[1:] if term.startswith('-') else term
            return name in valid_names

        return [term for term in ordering if is_valid(term)]

    def filter_queryset(self, request, queryset, view):
        """
        Order `queryset` by the requested fields, or by the view's default
        ordering when the request supplies no valid ordering.

        The queryset is returned unchanged when neither is available.
        """
        ordering = self.get_ordering(request)

        if ordering:
            # Skip any incorrect parameters
            ordering = self.remove_invalid_fields(queryset, ordering)

        if not ordering:
            # Use 'ordering' attribute by default
            ordering = self.get_default_ordering(view)

        if ordering:
            return queryset.order_by(*ordering)

        return queryset