Commit aed26b21 by Tom Christie

Drop out resources & mixins

parent 87b363f7
......@@ -18,7 +18,6 @@ __all__ = (
'IsUserOrIsAnonReadOnly',
'PerUserThrottling',
'PerViewThrottling',
'PerResourceThrottling'
)
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
......@@ -253,16 +252,3 @@ class PerViewThrottling(BaseThrottle):
def get_cache_key(self):
return 'throttle_view_%s' % self.view.__class__.__name__
class PerResourceThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be used against all views on
a given resource.
The class name of the resource is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
return 'throttle_resource_%s' % self.view.resource.__class__.__name__
"""
Customizable serialization.
"""
from django.db import models
from django.db.models.query import QuerySet
from django.utils.encoding import smart_unicode, is_protected_type, smart_str
import inspect
import types
# We register serializer classes, so that we can refer to them by their
# class names, if there are cyclical serialization heirachys.
_serializers = {}
def _field_to_tuple(field):
"""
Convert an item in the `fields` attribute into a 2-tuple.
"""
if isinstance(field, (tuple, list)):
return (field[0], field[1])
return (field, None)
def _fields_to_list(fields):
"""
Return a list of field tuples.
"""
return [_field_to_tuple(field) for field in fields or ()]
class _SkipField(Exception):
"""
Signals that a serialized field should be ignored.
We use this mechanism as the default behavior for ensuring
that we don't infinitely recurse when dealing with nested data.
"""
pass
class _RegisterSerializer(type):
"""
Metaclass to register serializers.
"""
def __new__(cls, name, bases, attrs):
# Build the class and register it.
ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs)
_serializers[name] = ret
return ret
class Serializer(object):
"""
Converts python objects into plain old native types suitable for
serialization. In particular it handles models and querysets.
The output format is specified by setting a number of attributes
on the class.
You may also override any of the serialization methods, to provide
for more flexible behavior.
Valid output types include anything that may be directly rendered into
json, xml etc...
"""
__metaclass__ = _RegisterSerializer
fields = ()
"""
Specify the fields to be serialized on a model or dict.
Overrides `include` and `exclude`.
"""
include = ()
"""
Fields to add to the default set to be serialized on a model/dict.
"""
exclude = ()
"""
Fields to remove from the default set to be serialized on a model/dict.
"""
rename = {}
"""
A dict of key->name to use for the field keys.
"""
related_serializer = None
"""
The default serializer class to use for any related models.
"""
depth = None
"""
The maximum depth to serialize to, or `None`.
"""
def __init__(self, depth=None, stack=[], **kwargs):
if depth is not None:
self.depth = depth
self.stack = stack
def get_fields(self, obj):
fields = self.fields
# If `fields` is not set, we use the default fields and modify
# them with `include` and `exclude`
if not fields:
default = self.get_default_fields(obj)
include = self.include or ()
exclude = self.exclude or ()
fields = set(default + list(include)) - set(exclude)
return fields
def get_default_fields(self, obj):
"""
Return the default list of field names/keys for a model instance/dict.
These are used if `fields` is not given.
"""
if isinstance(obj, models.Model):
opts = obj._meta
return [field.name for field in opts.fields + opts.many_to_many]
else:
return obj.keys()
def get_related_serializer(self, info):
# If an element in `fields` is a 2-tuple of (str, tuple)
# then the second element of the tuple is the fields to
# set on the related serializer
if isinstance(info, (list, tuple)):
class OnTheFlySerializer(self.__class__):
fields = info
return OnTheFlySerializer
# If an element in `fields` is a 2-tuple of (str, Serializer)
# then the second element of the tuple is the Serializer
# class to use for that field.
elif isinstance(info, type) and issubclass(info, Serializer):
return info
# If an element in `fields` is a 2-tuple of (str, str)
# then the second element of the tuple is the name of the Serializer
# class to use for that field.
#
# Black magic to deal with cyclical Serializer dependancies.
# Similar to what Django does for cyclically related models.
elif isinstance(info, str) and info in _serializers:
return _serializers[info]
# Otherwise use `related_serializer` or fall back to `Serializer`
return getattr(self, 'related_serializer') or Serializer
def serialize_key(self, key):
"""
Keys serialize to their string value,
unless they exist in the `rename` dict.
"""
return self.rename.get(smart_str(key), smart_str(key))
def serialize_val(self, key, obj, related_info):
"""
Convert a model field or dict value into a serializable representation.
"""
related_serializer = self.get_related_serializer(related_info)
if self.depth is None:
depth = None
elif self.depth <= 0:
return self.serialize_max_depth(obj)
else:
depth = self.depth - 1
if any([obj is elem for elem in self.stack]):
return self.serialize_recursion(obj)
else:
stack = self.stack[:]
stack.append(obj)
return related_serializer(depth=depth, stack=stack).serialize(obj)
def serialize_max_depth(self, obj):
"""
Determine how objects should be serialized once `depth` is exceeded.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_recursion(self, obj):
"""
Determine how objects should be serialized if recursion occurs.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_model(self, instance):
"""
Given a model instance or dict, serialize it to a dict..
"""
data = {}
fields = self.get_fields(instance)
# serialize each required field
for fname, related_info in _fields_to_list(fields):
try:
# we first check for a method 'fname' on self,
# 'fname's signature must be 'def fname(self, instance)'
meth = getattr(self, fname, None)
if (inspect.ismethod(meth) and
len(inspect.getargspec(meth)[0]) == 2):
obj = meth(instance)
elif hasattr(instance, '__contains__') and fname in instance:
# then check for a key 'fname' on the instance
obj = instance[fname]
elif hasattr(instance, smart_str(fname)):
# finally check for an attribute 'fname' on the instance
obj = getattr(instance, fname)
else:
continue
key = self.serialize_key(fname)
val = self.serialize_val(fname, obj, related_info)
data[key] = val
except _SkipField:
pass
return data
def serialize_iter(self, obj):
"""
Convert iterables into a serializable representation.
"""
return [self.serialize(item) for item in obj]
def serialize_func(self, obj):
"""
Convert no-arg methods and functions into a serializable representation.
"""
return self.serialize(obj())
def serialize_manager(self, obj):
"""
Convert a model manager into a serializable representation.
"""
return self.serialize_iter(obj.all())
def serialize_fallback(self, obj):
"""
Convert any unhandled object into a serializable representation.
"""
return smart_unicode(obj, strings_only=True)
def serialize(self, obj):
"""
Convert any object into a serializable representation.
"""
if isinstance(obj, (dict, models.Model)):
# Model instances & dictionaries
return self.serialize_model(obj)
elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)):
# basic iterables
return self.serialize_iter(obj)
elif isinstance(obj, models.Manager):
# Manager objects
return self.serialize_manager(obj)
elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]:
# function with no args
return self.serialize_func(obj)
elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1:
# bound method
return self.serialize_func(obj)
# Protected types are passed through as is.
# (i.e. Primitives like None, numbers, dates, and Decimals.)
if is_protected_type(obj):
return obj
# All other values are converted to string.
return self.serialize_fallback(obj)
from django.test import TestCase
from django import forms
# from django.test import TestCase
# from django import forms
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.resources import FormResource
from djangorestframework.response import Response
# from djangorestframework.compat import RequestFactory
# from djangorestframework.views import View
# from djangorestframework.response import Response
import StringIO
# import StringIO
class UploadFilesTests(TestCase):
"""Check uploading of files"""
def setUp(self):
self.factory = RequestFactory()
def test_upload_file(self):
# class UploadFilesTests(TestCase):
# """Check uploading of files"""
# def setUp(self):
# self.factory = RequestFactory()
class FileForm(forms.Form):
file = forms.FileField()
# def test_upload_file(self):
class MockView(View):
permissions = ()
form = FileForm
# class FileForm(forms.Form):
# file = forms.FileField()
def post(self, request, *args, **kwargs):
return Response({'FILE_NAME': self.CONTENT['file'].name,
'FILE_CONTENT': self.CONTENT['file'].read()})
# class MockView(View):
# permissions = ()
# form = FileForm
file = StringIO.StringIO('stuff')
file.name = 'stuff.txt'
request = self.factory.post('/', {'file': file})
view = MockView.as_view()
response = view(request)
self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
# def post(self, request, *args, **kwargs):
# return Response({'FILE_NAME': self.CONTENT['file'].name,
# 'FILE_CONTENT': self.CONTENT['file'].read()})
# file = StringIO.StringIO('stuff')
# file.name = 'stuff.txt'
# request = self.factory.post('/', {'file': file})
# view = MockView.as_view()
# response = view(request)
# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
from django.conf.urls.defaults import patterns, url
from django.forms import ModelForm
from django.contrib.auth.models import Group, User
from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase
# from django.conf.urls.defaults import patterns, url
# from django.forms import ModelForm
# from django.contrib.auth.models import Group, User
# from djangorestframework.resources import ModelResource
# from djangorestframework.views import ListOrCreateModelView, InstanceModelView
# from djangorestframework.tests.models import CustomUser
# from djangorestframework.tests.testcases import TestModelsTestCase
class GroupResource(ModelResource):
model = Group
# class GroupResource(ModelResource):
# model = Group
class UserForm(ModelForm):
class Meta:
model = User
exclude = ('last_login', 'date_joined')
# class UserForm(ModelForm):
# class Meta:
# model = User
# exclude = ('last_login', 'date_joined')
class UserResource(ModelResource):
model = User
form = UserForm
# class UserResource(ModelResource):
# model = User
# form = UserForm
class CustomUserResource(ModelResource):
model = CustomUser
# class CustomUserResource(ModelResource):
# model = CustomUser
urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
)
# urlpatterns = patterns('',
# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
# )
class ModelViewTests(TestModelsTestCase):
"""Test the model views djangorestframework provides"""
urls = 'djangorestframework.tests.modelviews'
# class ModelViewTests(TestModelsTestCase):
# """Test the model views djangorestframework provides"""
# urls = 'djangorestframework.tests.modelviews'
def test_creation(self):
"""Ensure that a model object can be created"""
self.assertEqual(0, Group.objects.count())
# def test_creation(self):
# """Ensure that a model object can be created"""
# self.assertEqual(0, Group.objects.count())
response = self.client.post('/groups/', {'name': 'foo'})
# response = self.client.post('/groups/', {'name': 'foo'})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, Group.objects.count())
self.assertEqual('foo', Group.objects.all()[0].name)
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, Group.objects.count())
# self.assertEqual('foo', Group.objects.all()[0].name)
def test_creation_with_m2m_relation(self):
"""Ensure that a model object with a m2m relation can be created"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
# def test_creation_with_m2m_relation(self):
# """Ensure that a model object with a m2m relation can be created"""
# group = Group(name='foo')
# group.save()
# self.assertEqual(0, User.objects.count())
response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, User.objects.count())
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, User.objects.count())
user = User.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual('baz', user.password)
self.assertEqual(1, user.groups.count())
# user = User.objects.all()[0]
# self.assertEqual('bar', user.username)
# self.assertEqual('baz', user.password)
# self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)
# group = user.groups.all()[0]
# self.assertEqual('foo', group.name)
def test_creation_with_m2m_relation_through(self):
"""
Ensure that a model object with a m2m relation can be created where that
relation uses a through table
"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
# def test_creation_with_m2m_relation_through(self):
# """
# Ensure that a model object with a m2m relation can be created where that
# relation uses a through table
# """
# group = Group(name='foo')
# group.save()
# self.assertEqual(0, User.objects.count())
response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, CustomUser.objects.count())
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, CustomUser.objects.count())
user = CustomUser.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual(1, user.groups.count())
# user = CustomUser.objects.all()[0]
# self.assertEqual('bar', user.username)
# self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)
# group = user.groups.all()[0]
# self.assertEqual('foo', group.name)
......@@ -8,8 +8,7 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling
from djangorestframework.response import Response
......@@ -25,11 +24,6 @@ class MockView_PerViewThrottling(MockView):
permission_classes = (PerViewThrottling,)
class MockView_PerResourceThrottling(MockView):
permission_classes = (PerResourceThrottling,)
resource = FormResource
class MockView_MinuteThrottling(MockView):
throttle = '3/min'
......@@ -98,12 +92,6 @@ class ThrottlingTests(TestCase):
"""
self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self):
"""
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
......
......@@ -13,8 +13,7 @@ from django.views.decorators.csrf import csrf_exempt
from djangorestframework.compat import View as DjangoView, apply_markdown
from djangorestframework.response import Response, ImmediateResponse
from djangorestframework.request import Request
from djangorestframework.mixins import *
from djangorestframework import resources, renderers, parsers, authentication, permissions, status
from djangorestframework import renderers, parsers, authentication, permissions, status
__all__ = (
......@@ -29,7 +28,7 @@ __all__ = (
def _remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view/resource classes.
Used when generating names from view classes.
"""
if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)]
......@@ -54,40 +53,26 @@ def _remove_leading_indent(content):
def _camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view/resource classes.
Used when generating names from view classes.
"""
camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
return re.sub(camelcase_boundry, ' \\1', content).strip()
_resource_classes = (
None,
resources.Resource,
resources.FormResource,
resources.ModelResource
)
class View(ResourceMixin, DjangoView):
class View(DjangoView):
"""
Handles incoming requests and maps them to REST operations.
Performs request deserialization, response serialization, authentication and input validation.
"""
resource = None
"""
The resource to use when validating requests and filtering responses,
or `None` to use default behaviour.
"""
renderers = renderers.DEFAULT_RENDERERS
"""
List of renderer classes the resource can serialize the response with, ordered by preference.
List of renderer classes the view can serialize the response with, ordered by preference.
"""
parsers = parsers.DEFAULT_PARSERS
"""
List of parser classes the resource can parse the request with.
List of parser classes the view can parse the request with.
"""
authentication = (authentication.UserLoggedInAuthentication,
......@@ -132,17 +117,8 @@ class View(ResourceMixin, DjangoView):
Return the resource or view class name for use as this view's name.
Override to customize.
"""
# If this view has a resource that's been overridden, then use that resource for the name
if getattr(self, 'resource', None) not in _resource_classes:
name = self.resource.__name__
name = _remove_trailing_string(name, 'Resource')
name += getattr(self, '_suffix', '')
# If it's a view class with no resource then grok the name from the class name
else:
name = self.__class__.__name__
name = _remove_trailing_string(name, 'View')
return _camelcase_to_spaces(name)
def get_description(self, html=False):
......@@ -150,20 +126,8 @@ class View(ResourceMixin, DjangoView):
Return the resource or view docstring for use as this view's description.
Override to customize.
"""
description = None
# If this view has a resource that's been overridden,
# then try to use the resource's docstring
if getattr(self, 'resource', None) not in _resource_classes:
description = self.resource.__doc__
# Otherwise use the view docstring
if not description:
description = self.__doc__ or ''
description = _remove_leading_indent(description)
if html:
return self.markup_description(description)
return description
......@@ -184,7 +148,7 @@ class View(ResourceMixin, DjangoView):
a handler method.
"""
content = {
'detail': "Method '%s' not allowed on this resource." % request.method
'detail': "Method '%s' not allowed." % request.method
}
raise ImmediateResponse(content, status.HTTP_405_METHOD_NOT_ALLOWED)
......@@ -283,10 +247,6 @@ class View(ResourceMixin, DjangoView):
response = handler(request, *args, **kwargs)
if isinstance(response, Response):
# Pre-serialize filtering (eg filter complex objects into natively serializable types)
response.raw_content = self.filter_response(response.raw_content)
except ImmediateResponse, exc:
response = exc.response
......@@ -307,31 +267,3 @@ class View(ResourceMixin, DjangoView):
field_name_types[name] = field.__class__.__name__
content['fields'] = field_name_types
raise ImmediateResponse(content, status=status.HTTP_200_OK)
class ModelView(View):
"""
A RESTful view that maps to a model in the database.
"""
resource = resources.ModelResource
class InstanceModelView(ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView):
"""
A view which provides default operations for read/update/delete against a model instance.
"""
_suffix = 'Instance'
class ListModelView(ListModelMixin, ModelView):
"""
A view which provides default operations for list, against a model in the database.
"""
_suffix = 'List'
class ListOrCreateModelView(ListModelMixin, CreateModelMixin, ModelView):
"""
A view which provides default operations for list and create, against a model in the database.
"""
_suffix = 'List'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment