Commit e05021c8 by Tom Christie

Guard against erronous direct .queryset evaluation in CBVs.

parent 9d136abb
...@@ -7,6 +7,7 @@ import inspect ...@@ -7,6 +7,7 @@ import inspect
import warnings import warnings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import models
from django.http import Http404 from django.http import Http404
from django.utils import six from django.utils import six
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
...@@ -118,8 +119,19 @@ class APIView(View): ...@@ -118,8 +119,19 @@ class APIView(View):
This allows us to discover information about the view when we do URL This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation. reverse lookups. Used for breadcrumb generation.
""" """
if isinstance(getattr(cls, 'queryset', None), models.QuerySet):
def force_evaluation():
raise AssertionError(
'Do not evaluate the `.queryset` attribute directly, '
'as the result will be cached and reused between requests. '
'Use `.all()` or call `.get_queryset()` instead.'
)
cls.queryset._fetch_all = force_evaluation
view = super(APIView, cls).as_view(**initkwargs) view = super(APIView, cls).as_view(**initkwargs)
view.cls = cls view.cls = cls
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
return csrf_exempt(view) return csrf_exempt(view)
......
from __future__ import unicode_literals from __future__ import unicode_literals
import django import django
import pytest
from django.db import models from django.db import models
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework import generics, renderers, serializers, status from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import ( from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel
...@@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase): ...@@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase):
response = view(request).render() response = view(request).render()
self.assertContains(response, 'field_b') self.assertContains(response, 'field_b')
self.assertNotContains(response, 'field_a') self.assertNotContains(response, 'field_a')
class TestGuardedQueryset(TestCase):
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
def get(self, request):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
request = factory.get('/')
with pytest.raises(AssertionError):
view(request).render()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment