Commit aed26b21 by Tom Christie

Drop out resources & mixins

parent 87b363f7
...@@ -18,7 +18,6 @@ __all__ = ( ...@@ -18,7 +18,6 @@ __all__ = (
'IsUserOrIsAnonReadOnly', 'IsUserOrIsAnonReadOnly',
'PerUserThrottling', 'PerUserThrottling',
'PerViewThrottling', 'PerViewThrottling',
'PerResourceThrottling'
) )
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
...@@ -253,16 +252,3 @@ class PerViewThrottling(BaseThrottle): ...@@ -253,16 +252,3 @@ class PerViewThrottling(BaseThrottle):
def get_cache_key(self): def get_cache_key(self):
return 'throttle_view_%s' % self.view.__class__.__name__ return 'throttle_view_%s' % self.view.__class__.__name__
class PerResourceThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be used against all views on
a given resource.
The class name of the resource is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
return 'throttle_resource_%s' % self.view.resource.__class__.__name__
"""
Customizable serialization.
"""
from django.db import models
from django.db.models.query import QuerySet
from django.utils.encoding import smart_unicode, is_protected_type, smart_str
import inspect
import types
# We register serializer classes, so that we can refer to them by their
# class names, if there are cyclical serialization heirachys.
_serializers = {}
def _field_to_tuple(field):
"""
Convert an item in the `fields` attribute into a 2-tuple.
"""
if isinstance(field, (tuple, list)):
return (field[0], field[1])
return (field, None)
def _fields_to_list(fields):
"""
Return a list of field tuples.
"""
return [_field_to_tuple(field) for field in fields or ()]
class _SkipField(Exception):
"""
Signals that a serialized field should be ignored.
We use this mechanism as the default behavior for ensuring
that we don't infinitely recurse when dealing with nested data.
"""
pass
class _RegisterSerializer(type):
"""
Metaclass to register serializers.
"""
def __new__(cls, name, bases, attrs):
# Build the class and register it.
ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs)
_serializers[name] = ret
return ret
class Serializer(object):
"""
Converts python objects into plain old native types suitable for
serialization. In particular it handles models and querysets.
The output format is specified by setting a number of attributes
on the class.
You may also override any of the serialization methods, to provide
for more flexible behavior.
Valid output types include anything that may be directly rendered into
json, xml etc...
"""
__metaclass__ = _RegisterSerializer
fields = ()
"""
Specify the fields to be serialized on a model or dict.
Overrides `include` and `exclude`.
"""
include = ()
"""
Fields to add to the default set to be serialized on a model/dict.
"""
exclude = ()
"""
Fields to remove from the default set to be serialized on a model/dict.
"""
rename = {}
"""
A dict of key->name to use for the field keys.
"""
related_serializer = None
"""
The default serializer class to use for any related models.
"""
depth = None
"""
The maximum depth to serialize to, or `None`.
"""
def __init__(self, depth=None, stack=[], **kwargs):
if depth is not None:
self.depth = depth
self.stack = stack
def get_fields(self, obj):
fields = self.fields
# If `fields` is not set, we use the default fields and modify
# them with `include` and `exclude`
if not fields:
default = self.get_default_fields(obj)
include = self.include or ()
exclude = self.exclude or ()
fields = set(default + list(include)) - set(exclude)
return fields
def get_default_fields(self, obj):
"""
Return the default list of field names/keys for a model instance/dict.
These are used if `fields` is not given.
"""
if isinstance(obj, models.Model):
opts = obj._meta
return [field.name for field in opts.fields + opts.many_to_many]
else:
return obj.keys()
def get_related_serializer(self, info):
# If an element in `fields` is a 2-tuple of (str, tuple)
# then the second element of the tuple is the fields to
# set on the related serializer
if isinstance(info, (list, tuple)):
class OnTheFlySerializer(self.__class__):
fields = info
return OnTheFlySerializer
# If an element in `fields` is a 2-tuple of (str, Serializer)
# then the second element of the tuple is the Serializer
# class to use for that field.
elif isinstance(info, type) and issubclass(info, Serializer):
return info
# If an element in `fields` is a 2-tuple of (str, str)
# then the second element of the tuple is the name of the Serializer
# class to use for that field.
#
# Black magic to deal with cyclical Serializer dependancies.
# Similar to what Django does for cyclically related models.
elif isinstance(info, str) and info in _serializers:
return _serializers[info]
# Otherwise use `related_serializer` or fall back to `Serializer`
return getattr(self, 'related_serializer') or Serializer
def serialize_key(self, key):
"""
Keys serialize to their string value,
unless they exist in the `rename` dict.
"""
return self.rename.get(smart_str(key), smart_str(key))
def serialize_val(self, key, obj, related_info):
"""
Convert a model field or dict value into a serializable representation.
"""
related_serializer = self.get_related_serializer(related_info)
if self.depth is None:
depth = None
elif self.depth <= 0:
return self.serialize_max_depth(obj)
else:
depth = self.depth - 1
if any([obj is elem for elem in self.stack]):
return self.serialize_recursion(obj)
else:
stack = self.stack[:]
stack.append(obj)
return related_serializer(depth=depth, stack=stack).serialize(obj)
def serialize_max_depth(self, obj):
"""
Determine how objects should be serialized once `depth` is exceeded.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_recursion(self, obj):
"""
Determine how objects should be serialized if recursion occurs.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_model(self, instance):
"""
Given a model instance or dict, serialize it to a dict..
"""
data = {}
fields = self.get_fields(instance)
# serialize each required field
for fname, related_info in _fields_to_list(fields):
try:
# we first check for a method 'fname' on self,
# 'fname's signature must be 'def fname(self, instance)'
meth = getattr(self, fname, None)
if (inspect.ismethod(meth) and
len(inspect.getargspec(meth)[0]) == 2):
obj = meth(instance)
elif hasattr(instance, '__contains__') and fname in instance:
# then check for a key 'fname' on the instance
obj = instance[fname]
elif hasattr(instance, smart_str(fname)):
# finally check for an attribute 'fname' on the instance
obj = getattr(instance, fname)
else:
continue
key = self.serialize_key(fname)
val = self.serialize_val(fname, obj, related_info)
data[key] = val
except _SkipField:
pass
return data
def serialize_iter(self, obj):
"""
Convert iterables into a serializable representation.
"""
return [self.serialize(item) for item in obj]
def serialize_func(self, obj):
"""
Convert no-arg methods and functions into a serializable representation.
"""
return self.serialize(obj())
def serialize_manager(self, obj):
"""
Convert a model manager into a serializable representation.
"""
return self.serialize_iter(obj.all())
def serialize_fallback(self, obj):
"""
Convert any unhandled object into a serializable representation.
"""
return smart_unicode(obj, strings_only=True)
def serialize(self, obj):
"""
Convert any object into a serializable representation.
"""
if isinstance(obj, (dict, models.Model)):
# Model instances & dictionaries
return self.serialize_model(obj)
elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)):
# basic iterables
return self.serialize_iter(obj)
elif isinstance(obj, models.Manager):
# Manager objects
return self.serialize_manager(obj)
elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]:
# function with no args
return self.serialize_func(obj)
elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1:
# bound method
return self.serialize_func(obj)
# Protected types are passed through as is.
# (i.e. Primitives like None, numbers, dates, and Decimals.)
if is_protected_type(obj):
return obj
# All other values are converted to string.
return self.serialize_fallback(obj)
from django.test import TestCase # from django.test import TestCase
from django import forms # from django import forms
from djangorestframework.compat import RequestFactory # from djangorestframework.compat import RequestFactory
from djangorestframework.views import View # from djangorestframework.views import View
from djangorestframework.resources import FormResource # from djangorestframework.response import Response
from djangorestframework.response import Response
import StringIO # import StringIO
class UploadFilesTests(TestCase):
"""Check uploading of files"""
def setUp(self):
self.factory = RequestFactory()
def test_upload_file(self): # class UploadFilesTests(TestCase):
# """Check uploading of files"""
# def setUp(self):
# self.factory = RequestFactory()
class FileForm(forms.Form): # def test_upload_file(self):
file = forms.FileField()
class MockView(View): # class FileForm(forms.Form):
permissions = () # file = forms.FileField()
form = FileForm
def post(self, request, *args, **kwargs): # class MockView(View):
return Response({'FILE_NAME': self.CONTENT['file'].name, # permissions = ()
'FILE_CONTENT': self.CONTENT['file'].read()}) # form = FileForm
file = StringIO.StringIO('stuff') # def post(self, request, *args, **kwargs):
file.name = 'stuff.txt' # return Response({'FILE_NAME': self.CONTENT['file'].name,
request = self.factory.post('/', {'file': file}) # 'FILE_CONTENT': self.CONTENT['file'].read()})
view = MockView.as_view()
response = view(request)
self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
# file = StringIO.StringIO('stuff')
# file.name = 'stuff.txt'
# request = self.factory.post('/', {'file': file})
# view = MockView.as_view()
# response = view(request)
# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
from django.conf.urls.defaults import patterns, url # from django.conf.urls.defaults import patterns, url
from django.forms import ModelForm # from django.forms import ModelForm
from django.contrib.auth.models import Group, User # from django.contrib.auth.models import Group, User
from djangorestframework.resources import ModelResource # from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView # from djangorestframework.views import ListOrCreateModelView, InstanceModelView
from djangorestframework.tests.models import CustomUser # from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase # from djangorestframework.tests.testcases import TestModelsTestCase
class GroupResource(ModelResource): # class GroupResource(ModelResource):
model = Group # model = Group
class UserForm(ModelForm): # class UserForm(ModelForm):
class Meta: # class Meta:
model = User # model = User
exclude = ('last_login', 'date_joined') # exclude = ('last_login', 'date_joined')
class UserResource(ModelResource): # class UserResource(ModelResource):
model = User # model = User
form = UserForm # form = UserForm
class CustomUserResource(ModelResource): # class CustomUserResource(ModelResource):
model = CustomUser # model = CustomUser
urlpatterns = patterns('', # urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), # url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), # url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), # url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), # url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), # url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), # url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
) # )
class ModelViewTests(TestModelsTestCase): # class ModelViewTests(TestModelsTestCase):
"""Test the model views djangorestframework provides""" # """Test the model views djangorestframework provides"""
urls = 'djangorestframework.tests.modelviews' # urls = 'djangorestframework.tests.modelviews'
def test_creation(self): # def test_creation(self):
"""Ensure that a model object can be created""" # """Ensure that a model object can be created"""
self.assertEqual(0, Group.objects.count()) # self.assertEqual(0, Group.objects.count())
response = self.client.post('/groups/', {'name': 'foo'}) # response = self.client.post('/groups/', {'name': 'foo'})
self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
self.assertEqual(1, Group.objects.count()) # self.assertEqual(1, Group.objects.count())
self.assertEqual('foo', Group.objects.all()[0].name) # self.assertEqual('foo', Group.objects.all()[0].name)
def test_creation_with_m2m_relation(self): # def test_creation_with_m2m_relation(self):
"""Ensure that a model object with a m2m relation can be created""" # """Ensure that a model object with a m2m relation can be created"""
group = Group(name='foo') # group = Group(name='foo')
group.save() # group.save()
self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) # response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
self.assertEqual(1, User.objects.count()) # self.assertEqual(1, User.objects.count())
user = User.objects.all()[0] # user = User.objects.all()[0]
self.assertEqual('bar', user.username) # self.assertEqual('bar', user.username)
self.assertEqual('baz', user.password) # self.assertEqual('baz', user.password)
self.assertEqual(1, user.groups.count()) # self.assertEqual(1, user.groups.count())
group = user.groups.all()[0] # group = user.groups.all()[0]
self.assertEqual('foo', group.name) # self.assertEqual('foo', group.name)
def test_creation_with_m2m_relation_through(self): # def test_creation_with_m2m_relation_through(self):
""" # """
Ensure that a model object with a m2m relation can be created where that # Ensure that a model object with a m2m relation can be created where that
relation uses a through table # relation uses a through table
""" # """
group = Group(name='foo') # group = Group(name='foo')
group.save() # group.save()
self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) # response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
self.assertEqual(1, CustomUser.objects.count()) # self.assertEqual(1, CustomUser.objects.count())
user = CustomUser.objects.all()[0] # user = CustomUser.objects.all()[0]
self.assertEqual('bar', user.username) # self.assertEqual('bar', user.username)
self.assertEqual(1, user.groups.count()) # self.assertEqual(1, user.groups.count())
group = user.groups.all()[0] # group = user.groups.all()[0]
self.assertEqual('foo', group.name) # self.assertEqual('foo', group.name)
...@@ -8,8 +8,7 @@ from django.core.cache import cache ...@@ -8,8 +8,7 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling from djangorestframework.permissions import PerUserThrottling, PerViewThrottling
from djangorestframework.resources import FormResource
from djangorestframework.response import Response from djangorestframework.response import Response
...@@ -25,11 +24,6 @@ class MockView_PerViewThrottling(MockView): ...@@ -25,11 +24,6 @@ class MockView_PerViewThrottling(MockView):
permission_classes = (PerViewThrottling,) permission_classes = (PerViewThrottling,)
class MockView_PerResourceThrottling(MockView):
permission_classes = (PerResourceThrottling,)
resource = FormResource
class MockView_MinuteThrottling(MockView): class MockView_MinuteThrottling(MockView):
throttle = '3/min' throttle = '3/min'
...@@ -98,12 +92,6 @@ class ThrottlingTests(TestCase): ...@@ -98,12 +92,6 @@ class ThrottlingTests(TestCase):
""" """
self.ensure_is_throttled(MockView_PerViewThrottling, 503) self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self):
"""
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
""" """
Ensure the response returns an X-Throttle field with status and next attributes Ensure the response returns an X-Throttle field with status and next attributes
......
...@@ -13,8 +13,7 @@ from django.views.decorators.csrf import csrf_exempt ...@@ -13,8 +13,7 @@ from django.views.decorators.csrf import csrf_exempt
from djangorestframework.compat import View as DjangoView, apply_markdown from djangorestframework.compat import View as DjangoView, apply_markdown
from djangorestframework.response import Response, ImmediateResponse from djangorestframework.response import Response, ImmediateResponse
from djangorestframework.request import Request from djangorestframework.request import Request
from djangorestframework.mixins import * from djangorestframework import renderers, parsers, authentication, permissions, status
from djangorestframework import resources, renderers, parsers, authentication, permissions, status
__all__ = ( __all__ = (
...@@ -29,7 +28,7 @@ __all__ = ( ...@@ -29,7 +28,7 @@ __all__ = (
def _remove_trailing_string(content, trailing): def _remove_trailing_string(content, trailing):
""" """
Strip trailing component `trailing` from `content` if it exists. Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view/resource classes. Used when generating names from view classes.
""" """
if content.endswith(trailing) and content != trailing: if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)] return content[:-len(trailing)]
...@@ -54,40 +53,26 @@ def _remove_leading_indent(content): ...@@ -54,40 +53,26 @@ def _remove_leading_indent(content):
def _camelcase_to_spaces(content): def _camelcase_to_spaces(content):
""" """
Translate 'CamelCaseNames' to 'Camel Case Names'. Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view/resource classes. Used when generating names from view classes.
""" """
camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
return re.sub(camelcase_boundry, ' \\1', content).strip() return re.sub(camelcase_boundry, ' \\1', content).strip()
_resource_classes = ( class View(DjangoView):
None,
resources.Resource,
resources.FormResource,
resources.ModelResource
)
class View(ResourceMixin, DjangoView):
""" """
Handles incoming requests and maps them to REST operations. Handles incoming requests and maps them to REST operations.
Performs request deserialization, response serialization, authentication and input validation. Performs request deserialization, response serialization, authentication and input validation.
""" """
resource = None
"""
The resource to use when validating requests and filtering responses,
or `None` to use default behaviour.
"""
renderers = renderers.DEFAULT_RENDERERS renderers = renderers.DEFAULT_RENDERERS
""" """
List of renderer classes the resource can serialize the response with, ordered by preference. List of renderer classes the view can serialize the response with, ordered by preference.
""" """
parsers = parsers.DEFAULT_PARSERS parsers = parsers.DEFAULT_PARSERS
""" """
List of parser classes the resource can parse the request with. List of parser classes the view can parse the request with.
""" """
authentication = (authentication.UserLoggedInAuthentication, authentication = (authentication.UserLoggedInAuthentication,
...@@ -132,17 +117,8 @@ class View(ResourceMixin, DjangoView): ...@@ -132,17 +117,8 @@ class View(ResourceMixin, DjangoView):
Return the resource or view class name for use as this view's name. Return the resource or view class name for use as this view's name.
Override to customize. Override to customize.
""" """
# If this view has a resource that's been overridden, then use that resource for the name
if getattr(self, 'resource', None) not in _resource_classes:
name = self.resource.__name__
name = _remove_trailing_string(name, 'Resource')
name += getattr(self, '_suffix', '')
# If it's a view class with no resource then grok the name from the class name
else:
name = self.__class__.__name__ name = self.__class__.__name__
name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'View')
return _camelcase_to_spaces(name) return _camelcase_to_spaces(name)
def get_description(self, html=False): def get_description(self, html=False):
...@@ -150,20 +126,8 @@ class View(ResourceMixin, DjangoView): ...@@ -150,20 +126,8 @@ class View(ResourceMixin, DjangoView):
Return the resource or view docstring for use as this view's description. Return the resource or view docstring for use as this view's description.
Override to customize. Override to customize.
""" """
description = None
# If this view has a resource that's been overridden,
# then try to use the resource's docstring
if getattr(self, 'resource', None) not in _resource_classes:
description = self.resource.__doc__
# Otherwise use the view docstring
if not description:
description = self.__doc__ or '' description = self.__doc__ or ''
description = _remove_leading_indent(description) description = _remove_leading_indent(description)
if html: if html:
return self.markup_description(description) return self.markup_description(description)
return description return description
...@@ -184,7 +148,7 @@ class View(ResourceMixin, DjangoView): ...@@ -184,7 +148,7 @@ class View(ResourceMixin, DjangoView):
a handler method. a handler method.
""" """
content = { content = {
'detail': "Method '%s' not allowed on this resource." % request.method 'detail': "Method '%s' not allowed." % request.method
} }
raise ImmediateResponse(content, status.HTTP_405_METHOD_NOT_ALLOWED) raise ImmediateResponse(content, status.HTTP_405_METHOD_NOT_ALLOWED)
...@@ -283,10 +247,6 @@ class View(ResourceMixin, DjangoView): ...@@ -283,10 +247,6 @@ class View(ResourceMixin, DjangoView):
response = handler(request, *args, **kwargs) response = handler(request, *args, **kwargs)
if isinstance(response, Response):
# Pre-serialize filtering (eg filter complex objects into natively serializable types)
response.raw_content = self.filter_response(response.raw_content)
except ImmediateResponse, exc: except ImmediateResponse, exc:
response = exc.response response = exc.response
...@@ -307,31 +267,3 @@ class View(ResourceMixin, DjangoView): ...@@ -307,31 +267,3 @@ class View(ResourceMixin, DjangoView):
field_name_types[name] = field.__class__.__name__ field_name_types[name] = field.__class__.__name__
content['fields'] = field_name_types content['fields'] = field_name_types
raise ImmediateResponse(content, status=status.HTTP_200_OK) raise ImmediateResponse(content, status=status.HTTP_200_OK)
class ModelView(View):
"""
A RESTful view that maps to a model in the database.
"""
resource = resources.ModelResource
class InstanceModelView(ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView):
"""
A view which provides default operations for read/update/delete against a model instance.
"""
_suffix = 'Instance'
class ListModelView(ListModelMixin, ModelView):
"""
A view which provides default operations for list, against a model in the database.
"""
_suffix = 'List'
class ListOrCreateModelView(ListModelMixin, CreateModelMixin, ModelView):
"""
A view which provides default operations for list and create, against a model in the database.
"""
_suffix = 'List'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment