Commit 39e13a0d by Philip Douglas

Merge remote-tracking branch 'upstream/master'

parents ef7ce344 f5c34926
...@@ -46,6 +46,11 @@ The default authentication schemes may be set globally, using the `DEFAULT_AUTHE ...@@ -46,6 +46,11 @@ The default authentication schemes may be set globally, using the `DEFAULT_AUTHE
You can also set the authentication scheme on a per-view or per-viewset basis, You can also set the authentication scheme on a per-view or per-viewset basis,
using the `APIView` class based views. using the `APIView` class based views.
from rest_framework.authentication import SessionAuthentication, BasicAuthentication
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
class ExampleView(APIView): class ExampleView(APIView):
authentication_classes = (SessionAuthentication, BasicAuthentication) authentication_classes = (SessionAuthentication, BasicAuthentication)
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
...@@ -157,11 +162,16 @@ The `curl` command line tool may be useful for testing token authenticated APIs. ...@@ -157,11 +162,16 @@ The `curl` command line tool may be useful for testing token authenticated APIs.
If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal. If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal.
from django.dispatch import receiver
from rest_framework.authtoken.models import Token
@receiver(post_save, sender=User) @receiver(post_save, sender=User)
def create_auth_token(sender, instance=None, created=False, **kwargs): def create_auth_token(sender, instance=None, created=False, **kwargs):
if created: if created:
Token.objects.create(user=instance) Token.objects.create(user=instance)
Note that you'll want to ensure you place this code snippet in an installed `models.py` module, or some other location that will be imported by Django on startup.
If you've already created some users, you can generate tokens for all existing users like this: If you've already created some users, you can generate tokens for all existing users like this:
from django.contrib.auth.models import User from django.contrib.auth.models import User
...@@ -336,6 +346,10 @@ If the `.authenticate_header()` method is not overridden, the authentication sch ...@@ -336,6 +346,10 @@ If the `.authenticate_header()` method is not overridden, the authentication sch
The following example will authenticate any incoming request as the user given by the username in a custom request header named 'X_USERNAME'. The following example will authenticate any incoming request as the user given by the username in a custom request header named 'X_USERNAME'.
from django.contrib.auth.models import User
from rest_framework import authentication
from rest_framework import exceptions
class ExampleAuthentication(authentication.BaseAuthentication): class ExampleAuthentication(authentication.BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
username = request.META.get('X_USERNAME') username = request.META.get('X_USERNAME')
...@@ -390,4 +404,4 @@ The [Django OAuth2 Consumer][doac] library from [Rediker Software][rediker] is a ...@@ -390,4 +404,4 @@ The [Django OAuth2 Consumer][doac] library from [Rediker Software][rediker] is a
[oauthlib]: https://github.com/idan/oauthlib [oauthlib]: https://github.com/idan/oauthlib
[doac]: https://github.com/Rediker-Software/doac [doac]: https://github.com/Rediker-Software/doac
[rediker]: https://github.com/Rediker-Software [rediker]: https://github.com/Rediker-Software
[doac-rest-framework]: https://github.com/Rediker-Software/doac/blob/master/docs/markdown/integrations.md# [doac-rest-framework]: https://github.com/Rediker-Software/doac/blob/master/docs/integrations.md#
...@@ -54,6 +54,8 @@ The `select_renderer()` method should return a two-tuple of (renderer instance, ...@@ -54,6 +54,8 @@ The `select_renderer()` method should return a two-tuple of (renderer instance,
The following is a custom content negotiation class which ignores the client The following is a custom content negotiation class which ignores the client
request when selecting the appropriate parser or renderer. request when selecting the appropriate parser or renderer.
from rest_framework.negotiation import BaseContentNegotiation
class IgnoreClientContentNegotiation(BaseContentNegotiation): class IgnoreClientContentNegotiation(BaseContentNegotiation):
def select_parser(self, request, parsers): def select_parser(self, request, parsers):
""" """
...@@ -77,6 +79,10 @@ The default content negotiation class may be set globally, using the `DEFAULT_CO ...@@ -77,6 +79,10 @@ The default content negotiation class may be set globally, using the `DEFAULT_CO
You can also set the content negotiation used for an individual view, or viewset, using the `APIView` class based views. You can also set the content negotiation used for an individual view, or viewset, using the `APIView` class based views.
from myapp.negotiation import IgnoreClientContentNegotiation
from rest_framework.response import Response
from rest_framework.views import APIView
class NoNegotiationView(APIView): class NoNegotiationView(APIView):
""" """
An example view that does not perform content negotiation. An example view that does not perform content negotiation.
......
...@@ -28,11 +28,54 @@ For example, the following request: ...@@ -28,11 +28,54 @@ For example, the following request:
Might receive an error response indicating that the `DELETE` method is not allowed on that resource: Might receive an error response indicating that the `DELETE` method is not allowed on that resource:
HTTP/1.1 405 Method Not Allowed HTTP/1.1 405 Method Not Allowed
Content-Type: application/json; charset=utf-8 Content-Type: application/json
Content-Length: 42 Content-Length: 42
{"detail": "Method 'DELETE' not allowed."} {"detail": "Method 'DELETE' not allowed."}
## Custom exception handling
You can implement custom exception handling by creating a handler function that converts exceptions raised in your API views into response objects. This allows you to control the style of error responses used by your API.
The function must take a single argument, which is the exception to be handled, and should either return a `Response` object, or return `None` if the exception cannot be handled. If the handler returns `None` then the exception will be re-raised and Django will return a standard HTTP 500 'server error' response.
For example, you might want to ensure that all error responses include the HTTP status code in the body of the response, like so:
HTTP/1.1 405 Method Not Allowed
Content-Type: application/json
Content-Length: 62
{"status_code": 405, "detail": "Method 'DELETE' not allowed."}
In order to alter the style of the response, you could write the following custom exception handler:
from rest_framework.views import exception_handler
def custom_exception_handler(exc):
# Call REST framework's default exception handler first,
# to get the standard error response.
response = exception_handler(exc)
# Now add the HTTP status code to the response.
if response is not None:
response.data['status_code'] = response.status_code
return response
The exception handler must also be configured in your settings, using the `EXCEPTION_HANDLER` setting key. For example:
REST_FRAMEWORK = {
'EXCEPTION_HANDLER': 'my_project.my_app.utils.custom_exception_handler'
}
If not specified, the `'EXCEPTION_HANDLER'` setting defaults to the standard exception handler provided by REST framework:
REST_FRAMEWORK = {
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler'
}
Note that the exception handler will only be called for responses generated by raised exceptions. It will not be used for any responses returned directly by the view, such as the `HTTP_400_BAD_REQUEST` responses that are returned by the generic views when serializer validation fails.
--- ---
# API Reference # API Reference
......
...@@ -78,6 +78,9 @@ A generic, **read-only** field. You can use this field for any attribute that d ...@@ -78,6 +78,9 @@ A generic, **read-only** field. You can use this field for any attribute that d
For example, using the following model. For example, using the following model.
from django.db import models
from django.utils.timezone import now
class Account(models.Model): class Account(models.Model):
owner = models.ForeignKey('auth.user') owner = models.ForeignKey('auth.user')
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
...@@ -85,13 +88,14 @@ For example, using the following model. ...@@ -85,13 +88,14 @@ For example, using the following model.
payment_expiry = models.DateTimeField() payment_expiry = models.DateTimeField()
def has_expired(self): def has_expired(self):
now = datetime.datetime.now() return now() > self.payment_expiry
return now > self.payment_expiry
A serializer definition that looked like this: A serializer definition that looked like this:
from rest_framework import serializers
class AccountSerializer(serializers.HyperlinkedModelSerializer): class AccountSerializer(serializers.HyperlinkedModelSerializer):
expired = Field(source='has_expired') expired = serializers.Field(source='has_expired')
class Meta: class Meta:
fields = ('url', 'owner', 'name', 'expired') fields = ('url', 'owner', 'name', 'expired')
...@@ -125,12 +129,11 @@ The `ModelField` class is generally intended for internal use, but can be used b ...@@ -125,12 +129,11 @@ The `ModelField` class is generally intended for internal use, but can be used b
This is a read-only field. It gets its value by calling a method on the serializer class it is attached to. It can be used to add any sort of data to the serialized representation of your object. The field's constructor accepts a single argument, which is the name of the method on the serializer to be called. The method should accept a single argument (in addition to `self`), which is the object being serialized. It should return whatever you want to be included in the serialized representation of the object. For example: This is a read-only field. It gets its value by calling a method on the serializer class it is attached to. It can be used to add any sort of data to the serialized representation of your object. The field's constructor accepts a single argument, which is the name of the method on the serializer to be called. The method should accept a single argument (in addition to `self`), which is the object being serialized. It should return whatever you want to be included in the serialized representation of the object. For example:
from rest_framework import serializers
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.utils.timezone import now from django.utils.timezone import now
from rest_framework import serializers
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
days_since_joined = serializers.SerializerMethodField('get_days_since_joined') days_since_joined = serializers.SerializerMethodField('get_days_since_joined')
class Meta: class Meta:
......
...@@ -20,6 +20,10 @@ You can do so by filtering based on the value of `request.user`. ...@@ -20,6 +20,10 @@ You can do so by filtering based on the value of `request.user`.
For example: For example:
from myapp.models import Purchase
from myapp.serializers import PurchaseSerializer
from rest_framework import generics
class PurchaseList(generics.ListAPIView) class PurchaseList(generics.ListAPIView)
serializer_class = PurchaseSerializer serializer_class = PurchaseSerializer
...@@ -90,6 +94,11 @@ The default filter backends may be set globally, using the `DEFAULT_FILTER_BACKE ...@@ -90,6 +94,11 @@ The default filter backends may be set globally, using the `DEFAULT_FILTER_BACKE
You can also set the filter backends on a per-view, or per-viewset basis, You can also set the filter backends on a per-view, or per-viewset basis,
using the `GenericAPIView` class based views. using the `GenericAPIView` class based views.
from django.contrib.auth.models import User
from myapp.serializers import UserSerializer
from rest_framework import filters
from rest_framework import generics
class UserListView(generics.ListAPIView): class UserListView(generics.ListAPIView):
queryset = User.objects.all() queryset = User.objects.all()
serializer = UserSerializer serializer = UserSerializer
...@@ -150,6 +159,11 @@ This will automatically create a `FilterSet` class for the given fields, and wil ...@@ -150,6 +159,11 @@ This will automatically create a `FilterSet` class for the given fields, and wil
For more advanced filtering requirements you can specify a `FilterSet` class that should be used by the view. For example: For more advanced filtering requirements you can specify a `FilterSet` class that should be used by the view. For example:
import django_filters
from myapp.models import Product
from myapp.serializers import ProductSerializer
from rest_framework import generics
class ProductFilter(django_filters.FilterSet): class ProductFilter(django_filters.FilterSet):
min_price = django_filters.NumberFilter(lookup_type='gte') min_price = django_filters.NumberFilter(lookup_type='gte')
max_price = django_filters.NumberFilter(lookup_type='lte') max_price = django_filters.NumberFilter(lookup_type='lte')
......
...@@ -17,6 +17,11 @@ If the generic views don't suit the needs of your API, you can drop down to usin ...@@ -17,6 +17,11 @@ If the generic views don't suit the needs of your API, you can drop down to usin
Typically when using the generic views, you'll override the view, and set several class attributes. Typically when using the generic views, you'll override the view, and set several class attributes.
from django.contrib.auth.models import User
from myapp.serializers import UserSerializer
from rest_framework import generics
from rest_framework.permissions import IsAdminUser
class UserList(generics.ListCreateAPIView): class UserList(generics.ListCreateAPIView):
queryset = User.objects.all() queryset = User.objects.all()
serializer_class = UserSerializer serializer_class = UserSerializer
...@@ -68,7 +73,7 @@ The following attributes control the basic view behavior. ...@@ -68,7 +73,7 @@ The following attributes control the basic view behavior.
**Pagination**: **Pagination**:
The following attibutes are used to control pagination when used with list views. The following attributes are used to control pagination when used with list views.
* `paginate_by` - The size of pages to use with paginated data. If set to `None` then pagination is turned off. If unset this uses the same value as the `PAGINATE_BY` setting, which defaults to `None`. * `paginate_by` - The size of pages to use with paginated data. If set to `None` then pagination is turned off. If unset this uses the same value as the `PAGINATE_BY` setting, which defaults to `None`.
* `paginate_by_param` - The name of a query parameter, which can be used by the client to override the default page size to use for pagination. If unset this uses the same value as the `PAGINATE_BY_PARAM` setting, which defaults to `None`. * `paginate_by_param` - The name of a query parameter, which can be used by the client to override the default page size to use for pagination. If unset this uses the same value as the `PAGINATE_BY_PARAM` setting, which defaults to `None`.
...@@ -108,7 +113,12 @@ For example: ...@@ -108,7 +113,12 @@ For example:
filter = {} filter = {}
for field in self.multiple_lookup_fields: for field in self.multiple_lookup_fields:
filter[field] = self.kwargs[field] filter[field] = self.kwargs[field]
return get_object_or_404(queryset, **filter)
obj = get_object_or_404(queryset, **filter)
self.check_object_permissions(self.request, obj)
return obj
Note that if your API doesn't include any object level permissions, you may optionally exclude the ``self.check_object_permissions, and simply return the object from the `get_object_or_404` lookup.
#### `get_serializer_class(self)` #### `get_serializer_class(self)`
...@@ -125,7 +135,7 @@ For example: ...@@ -125,7 +135,7 @@ For example:
#### `get_paginate_by(self)` #### `get_paginate_by(self)`
Returns the page size to use with pagination. By default this uses the `paginate_by` attribute, and may be overridden by the cient if the `paginate_by_param` attribute is set. Returns the page size to use with pagination. By default this uses the `paginate_by` attribute, and may be overridden by the client if the `paginate_by_param` attribute is set.
You may want to override this method to provide more complex behavior such as modifying page sizes based on the media type of the response. You may want to override this method to provide more complex behavior such as modifying page sizes based on the media type of the response.
......
...@@ -13,6 +13,7 @@ REST framework includes a `PaginationSerializer` class that makes it easy to ret ...@@ -13,6 +13,7 @@ REST framework includes a `PaginationSerializer` class that makes it easy to ret
Let's start by taking a look at an example from the Django documentation. Let's start by taking a look at an example from the Django documentation.
from django.core.paginator import Paginator from django.core.paginator import Paginator
objects = ['john', 'paul', 'george', 'ringo'] objects = ['john', 'paul', 'george', 'ringo']
paginator = Paginator(objects, 2) paginator = Paginator(objects, 2)
page = paginator.page(1) page = paginator.page(1)
...@@ -22,6 +23,7 @@ Let's start by taking a look at an example from the Django documentation. ...@@ -22,6 +23,7 @@ Let's start by taking a look at an example from the Django documentation.
At this point we've got a page object. If we wanted to return this page object as a JSON response, we'd need to provide the client with context such as next and previous links, so that it would be able to page through the remaining results. At this point we've got a page object. If we wanted to return this page object as a JSON response, we'd need to provide the client with context such as next and previous links, so that it would be able to page through the remaining results.
from rest_framework.pagination import PaginationSerializer from rest_framework.pagination import PaginationSerializer
serializer = PaginationSerializer(instance=page) serializer = PaginationSerializer(instance=page)
serializer.data serializer.data
# {'count': 4, 'next': '?page=2', 'previous': None, 'results': [u'john', u'paul']} # {'count': 4, 'next': '?page=2', 'previous': None, 'results': [u'john', u'paul']}
...@@ -83,11 +85,12 @@ We could now use our pagination serializer in a view like this. ...@@ -83,11 +85,12 @@ We could now use our pagination serializer in a view like this.
The generic class based views `ListAPIView` and `ListCreateAPIView` provide pagination of the returned querysets by default. You can customise this behaviour by altering the pagination style, by modifying the default number of results, by allowing clients to override the page size using a query parameter, or by turning pagination off completely. The generic class based views `ListAPIView` and `ListCreateAPIView` provide pagination of the returned querysets by default. You can customise this behaviour by altering the pagination style, by modifying the default number of results, by allowing clients to override the page size using a query parameter, or by turning pagination off completely.
The default pagination style may be set globally, using the `DEFAULT_PAGINATION_SERIALIZER_CLASS`, `PAGINATE_BY` and `PAGINATE_BY_PARAM` settings. For example. The default pagination style may be set globally, using the `DEFAULT_PAGINATION_SERIALIZER_CLASS`, `PAGINATE_BY`, `PAGINATE_BY_PARAM`, and `MAX_PAGINATE_BY` settings. For example.
REST_FRAMEWORK = { REST_FRAMEWORK = {
'PAGINATE_BY': 10, 'PAGINATE_BY': 10, # Default to 10
'PAGINATE_BY_PARAM': 'page_size' 'PAGINATE_BY_PARAM': 'page_size', # Allow client to override, using `?page_size=xxx`.
'MAX_PAGINATE_BY': 100 # Maximum limit allowed when using `?page_size=xxx`.
} }
You can also set the pagination style on a per-view basis, using the `ListAPIView` generic class-based view. You can also set the pagination style on a per-view basis, using the `ListAPIView` generic class-based view.
...@@ -97,6 +100,7 @@ You can also set the pagination style on a per-view basis, using the `ListAPIVie ...@@ -97,6 +100,7 @@ You can also set the pagination style on a per-view basis, using the `ListAPIVie
serializer_class = ExampleModelSerializer serializer_class = ExampleModelSerializer
paginate_by = 10 paginate_by = 10
paginate_by_param = 'page_size' paginate_by_param = 'page_size'
max_paginate_by = 100
Note that using a `paginate_by` value of `None` will turn off pagination for the view. Note that using a `paginate_by` value of `None` will turn off pagination for the view.
...@@ -114,6 +118,9 @@ You can also override the name used for the object list field, by setting the `r ...@@ -114,6 +118,9 @@ You can also override the name used for the object list field, by setting the `r
For example, to nest a pair of links labelled 'prev' and 'next', and set the name for the results field to 'objects', you might use something like this. For example, to nest a pair of links labelled 'prev' and 'next', and set the name for the results field to 'objects', you might use something like this.
from rest_framework import pagination
from rest_framework import serializers
class LinksSerializer(serializers.Serializer): class LinksSerializer(serializers.Serializer):
next = pagination.NextPageField(source='*') next = pagination.NextPageField(source='*')
prev = pagination.PreviousPageField(source='*') prev = pagination.PreviousPageField(source='*')
...@@ -135,7 +142,7 @@ To have your custom pagination serializer be used by default, use the `DEFAULT_P ...@@ -135,7 +142,7 @@ To have your custom pagination serializer be used by default, use the `DEFAULT_P
Alternatively, to set your custom pagination serializer on a per-view basis, use the `pagination_serializer_class` attribute on a generic class based view: Alternatively, to set your custom pagination serializer on a per-view basis, use the `pagination_serializer_class` attribute on a generic class based view:
class PaginatedListView(ListAPIView): class PaginatedListView(generics.ListAPIView):
model = ExampleModel model = ExampleModel
pagination_serializer_class = CustomPaginationSerializer pagination_serializer_class = CustomPaginationSerializer
paginate_by = 10 paginate_by = 10
......
...@@ -34,9 +34,13 @@ The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSE ...@@ -34,9 +34,13 @@ The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSE
) )
} }
You can also set the renderers used for an individual view, or viewset, You can also set the parsers used for an individual view, or viewset,
using the `APIView` class based views. using the `APIView` class based views.
from rest_framework.parsers import YAMLParser
from rest_framework.response import Response
from rest_framework.views import APIView
class ExampleView(APIView): class ExampleView(APIView):
""" """
A view that can accept POST requests with YAML content. A view that can accept POST requests with YAML content.
......
...@@ -25,9 +25,17 @@ Object level permissions are run by REST framework's generic views when `.get_ob ...@@ -25,9 +25,17 @@ Object level permissions are run by REST framework's generic views when `.get_ob
As with view level permissions, an `exceptions.PermissionDenied` exception will be raised if the user is not allowed to act on the given object. As with view level permissions, an `exceptions.PermissionDenied` exception will be raised if the user is not allowed to act on the given object.
If you're writing your own views and want to enforce object level permissions, If you're writing your own views and want to enforce object level permissions,
you'll need to explicitly call the `.check_object_permissions(request, obj)` method on the view at the point at which you've retrieved the object. or if you override the `get_object` method on a generic view, then you'll need to explicitly call the `.check_object_permissions(request, obj)` method on the view at the point at which you've retrieved the object.
This will either raise a `PermissionDenied` or `NotAuthenticated` exception, or simply return if the view has the appropriate permissions. This will either raise a `PermissionDenied` or `NotAuthenticated` exception, or simply return if the view has the appropriate permissions.
For example:
def get_object(self):
obj = get_object_or_404(self.get_queryset())
self.check_object_permissions(self.request, obj)
return obj
## Setting the permission policy ## Setting the permission policy
The default permission policy may be set globally, using the `DEFAULT_PERMISSION_CLASSES` setting. For example. The default permission policy may be set globally, using the `DEFAULT_PERMISSION_CLASSES` setting. For example.
...@@ -47,6 +55,10 @@ If not specified, this setting defaults to allowing unrestricted access: ...@@ -47,6 +55,10 @@ If not specified, this setting defaults to allowing unrestricted access:
You can also set the authentication policy on a per-view, or per-viewset basis, You can also set the authentication policy on a per-view, or per-viewset basis,
using the `APIView` class based views. using the `APIView` class based views.
from rest_framework.permissions import IsAuthenticated
from rest_framework.responses import Response
from rest_framework.views import APIView
class ExampleView(APIView): class ExampleView(APIView):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
...@@ -157,6 +169,8 @@ For more details see the [2.2 release announcement][2.2-announcement]. ...@@ -157,6 +169,8 @@ For more details see the [2.2 release announcement][2.2-announcement].
The following is an example of a permission class that checks the incoming request's IP address against a blacklist, and denies the request if the IP has been blacklisted. The following is an example of a permission class that checks the incoming request's IP address against a blacklist, and denies the request if the IP has been blacklisted.
from rest_framework import permissions
class BlacklistPermission(permissions.BasePermission): class BlacklistPermission(permissions.BasePermission):
""" """
Global permission check for blacklisted IPs. Global permission check for blacklisted IPs.
...@@ -198,6 +212,10 @@ The following third party packages are also available. ...@@ -198,6 +212,10 @@ The following third party packages are also available.
The [DRF Any Permissions][drf-any-permissions] packages provides a different permission behavior in contrast to REST framework. Instead of all specified permissions being required, only one of the given permissions has to be true in order to get access to the view. The [DRF Any Permissions][drf-any-permissions] packages provides a different permission behavior in contrast to REST framework. Instead of all specified permissions being required, only one of the given permissions has to be true in order to get access to the view.
## Composed Permissions
The [Composed Permissions][composed-permissions] package provides a simple way to define complex and multi-depth (with logic operators) permission objects, using small and reusable components.
[cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html
[authentication]: authentication.md [authentication]: authentication.md
[throttling]: throttling.md [throttling]: throttling.md
...@@ -208,3 +226,4 @@ The [DRF Any Permissions][drf-any-permissions] packages provides a different per ...@@ -208,3 +226,4 @@ The [DRF Any Permissions][drf-any-permissions] packages provides a different per
[2.2-announcement]: ../topics/2.2-announcement.md [2.2-announcement]: ../topics/2.2-announcement.md
[filtering]: filtering.md [filtering]: filtering.md
[drf-any-permissions]: https://github.com/kevin-brown/drf-any-permissions [drf-any-permissions]: https://github.com/kevin-brown/drf-any-permissions
[composed-permissions]: https://github.com/niwibe/djangorestframework-composed-permissions
...@@ -76,7 +76,7 @@ This field is read only. ...@@ -76,7 +76,7 @@ This field is read only.
For example, the following serializer: For example, the following serializer:
class AlbumSerializer(serializers.ModelSerializer): class AlbumSerializer(serializers.ModelSerializer):
tracks = PrimaryKeyRelatedField(many=True, read_only=True) tracks = serializers.PrimaryKeyRelatedField(many=True, read_only=True)
class Meta: class Meta:
model = Album model = Album
...@@ -110,7 +110,7 @@ By default this field is read-write, although you can change this behavior using ...@@ -110,7 +110,7 @@ By default this field is read-write, although you can change this behavior using
For example, the following serializer: For example, the following serializer:
class AlbumSerializer(serializers.ModelSerializer): class AlbumSerializer(serializers.ModelSerializer):
tracks = HyperlinkedRelatedField(many=True, read_only=True, tracks = serializers.HyperlinkedRelatedField(many=True, read_only=True,
view_name='track-detail') view_name='track-detail')
class Meta: class Meta:
...@@ -148,7 +148,8 @@ By default this field is read-write, although you can change this behavior using ...@@ -148,7 +148,8 @@ By default this field is read-write, although you can change this behavior using
For example, the following serializer: For example, the following serializer:
class AlbumSerializer(serializers.ModelSerializer): class AlbumSerializer(serializers.ModelSerializer):
tracks = SlugRelatedField(many=True, read_only=True, slug_field='title') tracks = serializers.SlugRelatedField(many=True, read_only=True,
slug_field='title')
class Meta: class Meta:
model = Album model = Album
...@@ -183,7 +184,7 @@ When using `SlugRelatedField` as a read-write field, you will normally want to e ...@@ -183,7 +184,7 @@ When using `SlugRelatedField` as a read-write field, you will normally want to e
This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. It can also be used for an attribute on the object. For example, the following serializer: This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. It can also be used for an attribute on the object. For example, the following serializer:
class AlbumSerializer(serializers.HyperlinkedModelSerializer): class AlbumSerializer(serializers.HyperlinkedModelSerializer):
track_listing = HyperlinkedIdentityField(view_name='track-list') track_listing = serializers.HyperlinkedIdentityField(view_name='track-list')
class Meta: class Meta:
model = Album model = Album
...@@ -213,8 +214,6 @@ Nested relationships can be expressed by using serializers as fields. ...@@ -213,8 +214,6 @@ Nested relationships can be expressed by using serializers as fields.
If the field is used to represent a to-many relationship, you should add the `many=True` flag to the serializer field. If the field is used to represent a to-many relationship, you should add the `many=True` flag to the serializer field.
Note that nested relationships are currently read-only. For read-write relationships, you should use a flat relational style.
## Example ## Example
For example, the following serializer: For example, the following serializer:
...@@ -422,7 +421,7 @@ For example, if all your object URLs used both a account and a slug in the the U ...@@ -422,7 +421,7 @@ For example, if all your object URLs used both a account and a slug in the the U
def get_object(self, queryset, view_name, view_args, view_kwargs): def get_object(self, queryset, view_name, view_args, view_kwargs):
account = view_kwargs['account'] account = view_kwargs['account']
slug = view_kwargs['slug'] slug = view_kwargs['slug']
return queryset.get(account=account, slug=sug) return queryset.get(account=account, slug=slug)
--- ---
......
...@@ -30,11 +30,16 @@ The default set of renderers may be set globally, using the `DEFAULT_RENDERER_CL ...@@ -30,11 +30,16 @@ The default set of renderers may be set globally, using the `DEFAULT_RENDERER_CL
You can also set the renderers used for an individual view, or viewset, You can also set the renderers used for an individual view, or viewset,
using the `APIView` class based views. using the `APIView` class based views.
from django.contrib.auth.models import User
from rest_framework.renderers import JSONRenderer, YAMLRenderer
from rest_framework.response import Response
from rest_framework.views import APIView
class UserCountView(APIView): class UserCountView(APIView):
""" """
A view that returns the count of active users, in JSON or JSONp. A view that returns the count of active users, in JSON or YAML.
""" """
renderer_classes = (JSONRenderer, JSONPRenderer) renderer_classes = (JSONRenderer, YAMLRenderer)
def get(self, request, format=None): def get(self, request, format=None):
user_count = User.objects.filter(active=True).count() user_count = User.objects.filter(active=True).count()
...@@ -83,7 +88,7 @@ The client may additionally include an `'indent'` media type parameter, in which ...@@ -83,7 +88,7 @@ The client may additionally include an `'indent'` media type parameter, in which
**.format**: `'.json'` **.format**: `'.json'`
**.charset**: `utf-8` **.charset**: `None`
## UnicodeJSONRenderer ## UnicodeJSONRenderer
...@@ -105,7 +110,7 @@ Both the `JSONRenderer` and `UnicodeJSONRenderer` styles conform to [RFC 4627][r ...@@ -105,7 +110,7 @@ Both the `JSONRenderer` and `UnicodeJSONRenderer` styles conform to [RFC 4627][r
**.format**: `'.json'` **.format**: `'.json'`
**.charset**: `utf-8` **.charset**: `None`
## JSONPRenderer ## JSONPRenderer
...@@ -207,6 +212,20 @@ You can use `TemplateHTMLRenderer` either to return regular HTML pages using RES ...@@ -207,6 +212,20 @@ You can use `TemplateHTMLRenderer` either to return regular HTML pages using RES
See also: `TemplateHTMLRenderer` See also: `TemplateHTMLRenderer`
## HTMLFormRenderer
Renders data returned by a serializer into an HTML form. The output of this renderer does not include the enclosing `<form>` tags or an submit actions, as you'll probably need those to include the desired method and URL. Also note that the `HTMLFormRenderer` does not yet support including field error messages.
Note that the template used by the `HTMLFormRenderer` class, and the context submitted to it **may be subject to change**. If you need to use this renderer class it is advised that you either make a local copy of the class and templates, or follow the release note on REST framework upgrades closely.
**.media_type**: `text/html`
**.format**: `'.form'`
**.charset**: `utf-8`
**.template**: `'rest_framework/form.html'`
## BrowsableAPIRenderer ## BrowsableAPIRenderer
Renders data into HTML for the Browsable API. This renderer will determine which other renderer would have been given highest priority, and use that to display an API style response within the HTML page. Renders data into HTML for the Browsable API. This renderer will determine which other renderer would have been given highest priority, and use that to display an API style response within the HTML page.
...@@ -217,6 +236,8 @@ Renders data into HTML for the Browsable API. This renderer will determine whic ...@@ -217,6 +236,8 @@ Renders data into HTML for the Browsable API. This renderer will determine whic
**.charset**: `utf-8` **.charset**: `utf-8`
**.template**: `'rest_framework/api.html'`
#### Customizing BrowsableAPIRenderer #### Customizing BrowsableAPIRenderer
By default the response content will be rendered with the highest priority renderer apart from `BrowseableAPIRenderer`. If you need to customize this behavior, for example to use HTML as the default return format, but use JSON in the browsable API, you can do so by overriding the `get_default_renderer()` method. For example: By default the response content will be rendered with the highest priority renderer apart from `BrowseableAPIRenderer`. If you need to customize this behavior, for example to use HTML as the default return format, but use JSON in the browsable API, you can do so by overriding the `get_default_renderer()` method. For example:
...@@ -290,12 +311,15 @@ By default renderer classes are assumed to be using the `UTF-8` encoding. To us ...@@ -290,12 +311,15 @@ By default renderer classes are assumed to be using the `UTF-8` encoding. To us
Note that if a renderer class returns a unicode string, then the response content will be coerced into a bytestring by the `Response` class, with the `charset` attribute set on the renderer used to determine the encoding. Note that if a renderer class returns a unicode string, then the response content will be coerced into a bytestring by the `Response` class, with the `charset` attribute set on the renderer used to determine the encoding.
If the renderer returns a bytestring representing raw binary content, you should set a charset value of `None`, which will ensure the `Content-Type` header of the response will not have a `charset` value set. Doing so will also ensure that the browsable API will not attempt to display the binary content as a string. If the renderer returns a bytestring representing raw binary content, you should set a charset value of `None`, which will ensure the `Content-Type` header of the response will not have a `charset` value set.
In some cases you may also want to set the `render_style` attribute to `'binary'`. Doing so will also ensure that the browsable API will not attempt to display the binary content as a string.
class JPEGRenderer(renderers.BaseRenderer): class JPEGRenderer(renderers.BaseRenderer):
media_type = 'image/jpeg' media_type = 'image/jpeg'
format = 'jpg' format = 'jpg'
charset = None charset = None
render_style = 'binary'
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
return data return data
......
...@@ -117,7 +117,7 @@ For more information see the [browser enhancements documentation]. ...@@ -117,7 +117,7 @@ For more information see the [browser enhancements documentation].
# Standard HttpRequest attributes # Standard HttpRequest attributes
As REST framework's `Request` extends Django's `HttpRequest`, all the other standard attributes and methods are also available. For example the `request.META` dictionary is available as normal. As REST framework's `Request` extends Django's `HttpRequest`, all the other standard attributes and methods are also available. For example the `request.META` and `request.session` dictionaries are available as normal.
Note that due to implementation reasons the `Request` class does not inherit from `HttpRequest` class, but instead extends the class using composition. Note that due to implementation reasons the `Request` class does not inherit from `HttpRequest` class, but instead extends the class using composition.
......
...@@ -27,13 +27,13 @@ Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except t ...@@ -27,13 +27,13 @@ Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except t
You should **include the request as a keyword argument** to the function, for example: You should **include the request as a keyword argument** to the function, for example:
import datetime
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.views import APIView from rest_framework.views import APIView
from django.utils.timezone import now
class APIRootView(APIView): class APIRootView(APIView):
def get(self, request): def get(self, request):
year = datetime.datetime.now().year year = now().year
data = { data = {
... ...
'year-summary-url': reverse('year-summary', args=[year], request=request) 'year-summary-url': reverse('year-summary', args=[year], request=request)
......
...@@ -14,6 +14,8 @@ REST framework adds support for automatic URL routing to Django, and provides yo ...@@ -14,6 +14,8 @@ REST framework adds support for automatic URL routing to Django, and provides yo
Here's an example of a simple URL conf, that uses `DefaultRouter`. Here's an example of a simple URL conf, that uses `DefaultRouter`.
from rest_framework import routers
router = routers.SimpleRouter() router = routers.SimpleRouter()
router.register(r'users', UserViewSet) router.register(r'users', UserViewSet)
router.register(r'accounts', AccountViewSet) router.register(r'accounts', AccountViewSet)
...@@ -40,6 +42,9 @@ The example above would generate the following URL patterns: ...@@ -40,6 +42,9 @@ The example above would generate the following URL patterns:
Any methods on the viewset decorated with `@link` or `@action` will also be routed. Any methods on the viewset decorated with `@link` or `@action` will also be routed.
For example, given a method like this on the `UserViewSet` class: For example, given a method like this on the `UserViewSet` class:
from myapp.permissions import IsAdminOrIsSelf
from rest_framework.decorators import action
@action(permission_classes=[IsAdminOrIsSelf]) @action(permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None): def set_password(self, request, pk=None):
... ...
...@@ -120,6 +125,8 @@ The arguments to the `Route` named tuple are: ...@@ -120,6 +125,8 @@ The arguments to the `Route` named tuple are:
The following example will only route to the `list` and `retrieve` actions, and does not use the trailing slash convention. The following example will only route to the `list` and `retrieve` actions, and does not use the trailing slash convention.
from rest_framework.routers import Route, SimpleRouter
class ReadOnlyRouter(SimpleRouter): class ReadOnlyRouter(SimpleRouter):
""" """
A router for read-only APIs, which doesn't use trailing slashes. A router for read-only APIs, which doesn't use trailing slashes.
......
...@@ -28,6 +28,8 @@ We'll declare a serializer that we can use to serialize and deserialize `Comment ...@@ -28,6 +28,8 @@ We'll declare a serializer that we can use to serialize and deserialize `Comment
Declaring a serializer looks very similar to declaring a form: Declaring a serializer looks very similar to declaring a form:
from rest_framework import serializers
class CommentSerializer(serializers.Serializer): class CommentSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
content = serializers.CharField(max_length=200) content = serializers.CharField(max_length=200)
...@@ -59,6 +61,8 @@ We can now use `CommentSerializer` to serialize a comment, or list of comments. ...@@ -59,6 +61,8 @@ We can now use `CommentSerializer` to serialize a comment, or list of comments.
At this point we've translated the model instance into Python native datatypes. To finalise the serialization process we render the data into `json`. At this point we've translated the model instance into Python native datatypes. To finalise the serialization process we render the data into `json`.
from rest_framework.renderers import JSONRenderer
json = JSONRenderer().render(serializer.data) json = JSONRenderer().render(serializer.data)
json json
# '{"email": "leila@example.com", "content": "foo bar", "created": "2012-08-22T16:20:09.822"}' # '{"email": "leila@example.com", "content": "foo bar", "created": "2012-08-22T16:20:09.822"}'
...@@ -67,6 +71,9 @@ At this point we've translated the model instance into Python native datatypes. ...@@ -67,6 +71,9 @@ At this point we've translated the model instance into Python native datatypes.
Deserialization is similar. First we parse a stream into Python native datatypes... Deserialization is similar. First we parse a stream into Python native datatypes...
from StringIO import StringIO
from rest_framework.parsers import JSONParser
stream = StringIO(json) stream = StringIO(json)
data = JSONParser().parse(stream) data = JSONParser().parse(stream)
...@@ -177,7 +184,7 @@ If a nested representation may optionally accept the `None` value you should pas ...@@ -177,7 +184,7 @@ If a nested representation may optionally accept the `None` value you should pas
content = serializers.CharField(max_length=200) content = serializers.CharField(max_length=200)
created = serializers.DateTimeField() created = serializers.DateTimeField()
Similarly if a nested representation should be a list of items, you should the `many=True` flag to the nested serialized. Similarly if a nested representation should be a list of items, you should pass the `many=True` flag to the nested serialized.
class CommentSerializer(serializers.Serializer): class CommentSerializer(serializers.Serializer):
user = UserSerializer(required=False) user = UserSerializer(required=False)
...@@ -185,11 +192,13 @@ Similarly if a nested representation should be a list of items, you should the ` ...@@ -185,11 +192,13 @@ Similarly if a nested representation should be a list of items, you should the `
content = serializers.CharField(max_length=200) content = serializers.CharField(max_length=200)
created = serializers.DateTimeField() created = serializers.DateTimeField()
--- Validation of nested objects will work the same as before. Errors with nested objects will be nested under the field name of the nested object.
**Note**: Nested serializers are only suitable for read-only representations, as there are cases where they would have ambiguous or non-obvious behavior if used when updating instances. For read-write representations you should always use a flat representation, by using one of the `RelatedField` subclasses.
--- serializer = CommentSerializer(comment, data={'user': {'email': 'foobar', 'username': 'doe'}, 'content': 'baz'})
serializer.is_valid()
# False
serializer.errors
# {'user': {'email': [u'Enter a valid e-mail address.']}, 'created': [u'This field is required.']}
## Dealing with multiple objects ## Dealing with multiple objects
...@@ -241,7 +250,7 @@ This allows you to write views that update or create multiple items when a `PUT` ...@@ -241,7 +250,7 @@ This allows you to write views that update or create multiple items when a `PUT`
serializer = BookSerializer(queryset, data=data, many=True) serializer = BookSerializer(queryset, data=data, many=True)
serializer.is_valid() serializer.is_valid()
# True # True
serialize.save() # `.save()` will be called on each updated or newly created instance. serializer.save() # `.save()` will be called on each updated or newly created instance.
By default bulk updates will be limited to updating instances that already exist in the provided queryset. By default bulk updates will be limited to updating instances that already exist in the provided queryset.
...@@ -293,8 +302,7 @@ You can provide arbitrary additional context by passing a `context` argument whe ...@@ -293,8 +302,7 @@ You can provide arbitrary additional context by passing a `context` argument whe
The context dictionary can be used within any serializer field logic, such as a custom `.to_native()` method, by accessing the `self.context` attribute. The context dictionary can be used within any serializer field logic, such as a custom `.to_native()` method, by accessing the `self.context` attribute.
--- -
# ModelSerializer # ModelSerializer
Often you'll want serializer classes that map closely to model definitions. Often you'll want serializer classes that map closely to model definitions.
...@@ -337,6 +345,8 @@ The default `ModelSerializer` uses primary keys for relationships, but you can a ...@@ -337,6 +345,8 @@ The default `ModelSerializer` uses primary keys for relationships, but you can a
The `depth` option should be set to an integer value that indicates the depth of relationships that should be traversed before reverting to a flat representation. The `depth` option should be set to an integer value that indicates the depth of relationships that should be traversed before reverting to a flat representation.
If you want to customize the way the serialization is done (e.g. using `allow_add_remove`) you'll need to define the field yourself.
## Specifying which fields should be read-only ## Specifying which fields should be read-only
You may wish to specify multiple fields as read-only. Instead of adding each field explicitly with the `read_only=True` attribute, you may use the `read_only_fields` Meta option, like so: You may wish to specify multiple fields as read-only. Instead of adding each field explicitly with the `read_only=True` attribute, you may use the `read_only_fields` Meta option, like so:
......
...@@ -127,6 +127,35 @@ Default: `None` ...@@ -127,6 +127,35 @@ Default: `None`
The name of a query parameter, which can be used by the client to override the default page size to use for pagination. If set to `None`, clients may not override the default page size. The name of a query parameter, which can be used by the client to override the default page size to use for pagination. If set to `None`, clients may not override the default page size.
For example, given the following settings:
REST_FRAMEWORK = {
'PAGINATE_BY': 10,
'PAGINATE_BY_PARAM': 'page_size',
}
A client would be able to modify the pagination size by using the `page_size` query parameter. For example:
GET http://example.com/api/accounts?page_size=25
Default: `None`
#### MAX_PAGINATE_BY
The maximum page size to allow when the page size is specified by the client. If set to `None`, then no maximum limit is applied.
For example, given the following settings:
REST_FRAMEWORK = {
'PAGINATE_BY': 10,
'PAGINATE_BY_PARAM': 'page_size',
'MAX_PAGINATE_BY': 100
}
A client request like the following would return a paginated list of up to 100 items.
GET http://example.com/api/accounts?page_size=999
Default: `None` Default: `None`
--- ---
...@@ -274,8 +303,56 @@ Default: `['iso-8601']` ...@@ -274,8 +303,56 @@ Default: `['iso-8601']`
--- ---
## View names and descriptions
**The following settings are used to generate the view names and descriptions, as used in responses to `OPTIONS` requests, and as used in the browsable API.**
#### VIEW_NAME_FUNCTION
A string representing the function that should be used when generating view names.
This should be a function with the following signature:
view_name(cls, suffix=None)
* `cls`: The view class. Typically the name function would inspect the name of the class when generating a descriptive name, by accessing `cls.__name__`.
* `suffix`: The optional suffix used when differentiating individual views in a viewset.
Default: `'rest_framework.views.get_view_name'`
#### VIEW_DESCRIPTION_FUNCTION
A string representing the function that should be used when generating view descriptions.
This setting can be changed to support markup styles other than the default markdown. For example, you can use it to support `rst` markup in your view docstrings being output in the browsable API.
This should be a function with the following signature:
view_description(cls, html=False)
* `cls`: The view class. Typically the description function would inspect the docstring of the class when generating a description, by accessing `cls.__doc__`
* `html`: A boolean indicating if HTML output is required. `True` when used in the browsable API, and `False` when used in generating `OPTIONS` responses.
Default: `'rest_framework.views.get_view_description'`
---
## Miscellaneous settings ## Miscellaneous settings
#### EXCEPTION_HANDLER
A string representing the function that should be used when returning a response for any given exception. If the function returns `None`, a 500 error will be raised.
This setting can be changed to support error responses other than the default `{"detail": "Failure..."}` responses. For example, you can use it to provide API responses like `{"errors": [{"message": "Failure...", "code": ""} ...]}`.
This should be a function with the following signature:
exception_handler(exc)
* `exc`: The exception.
Default: `'rest_framework.views.exception_handler'`
#### FORMAT_SUFFIX_KWARG #### FORMAT_SUFFIX_KWARG
The name of a parameter in the URL conf that may be used to provide a format suffix. The name of a parameter in the URL conf that may be used to provide a format suffix.
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
Using bare status codes in your responses isn't recommended. REST framework includes a set of named constants that you can use to make more code more obvious and readable. Using bare status codes in your responses isn't recommended. REST framework includes a set of named constants that you can use to make more code more obvious and readable.
from rest_framework import status from rest_framework import status
from rest_framework.response import Response
def empty_view(self): def empty_view(self):
content = {'please move along': 'nothing to see here'} content = {'please move along': 'nothing to see here'}
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Testing # Testing
> Code without tests is broken as designed > Code without tests is broken as designed.
> >
> &mdash; [Jacob Kaplan-Moss][cite] > &mdash; [Jacob Kaplan-Moss][cite]
...@@ -16,6 +16,8 @@ Extends [Django's existing `RequestFactory` class][requestfactory]. ...@@ -16,6 +16,8 @@ Extends [Django's existing `RequestFactory` class][requestfactory].
The `APIRequestFactory` class supports an almost identical API to Django's standard `RequestFactory` class. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. The `APIRequestFactory` class supports an almost identical API to Django's standard `RequestFactory` class. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available.
from rest_framework.test import APIRequestFactory
# Using the standard RequestFactory API to create a form POST request # Using the standard RequestFactory API to create a form POST request
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.post('/notes/', {'title': 'new idea'}) request = factory.post('/notes/', {'title': 'new idea'})
...@@ -49,6 +51,8 @@ For example, using `APIRequestFactory`, you can make a form PUT request like so: ...@@ -49,6 +51,8 @@ For example, using `APIRequestFactory`, you can make a form PUT request like so:
Using Django's `RequestFactory`, you'd need to explicitly encode the data yourself: Using Django's `RequestFactory`, you'd need to explicitly encode the data yourself:
from django.test.client import encode_multipart, RequestFactory
factory = RequestFactory() factory = RequestFactory()
data = {'title': 'remember to email dave'} data = {'title': 'remember to email dave'}
content = encode_multipart('BoUnDaRyStRiNg', data) content = encode_multipart('BoUnDaRyStRiNg', data)
...@@ -72,6 +76,12 @@ To forcibly authenticate a request, use the `force_authenticate()` method. ...@@ -72,6 +76,12 @@ To forcibly authenticate a request, use the `force_authenticate()` method.
The signature for the method is `force_authenticate(request, user=None, token=None)`. When making the call, either or both of the user and token may be set. The signature for the method is `force_authenticate(request, user=None, token=None)`. When making the call, either or both of the user and token may be set.
For example, when forcibly authenticating using a token, you might do something like the following:
user = User.objects.get(username='olivia')
request = factory.get('/accounts/django-superstars/')
force_authenticate(request, user=user, token=user.token)
--- ---
**Note**: When using `APIRequestFactory`, the object that is returned is Django's standard `HttpRequest`, and not REST framework's `Request` object, which is only generated once the view is called. **Note**: When using `APIRequestFactory`, the object that is returned is Django's standard `HttpRequest`, and not REST framework's `Request` object, which is only generated once the view is called.
...@@ -105,6 +115,8 @@ Extends [Django's existing `Client` class][client]. ...@@ -105,6 +115,8 @@ Extends [Django's existing `Client` class][client].
The `APIClient` class supports the same request interface as `APIRequestFactory`. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. For example: The `APIClient` class supports the same request interface as `APIRequestFactory`. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. For example:
from rest_framework.test import APIClient
client = APIClient() client = APIClient()
client.post('/notes/', {'title': 'new idea'}, format='json') client.post('/notes/', {'title': 'new idea'}, format='json')
...@@ -131,8 +143,11 @@ The `login` method is appropriate for testing APIs that use session authenticati ...@@ -131,8 +143,11 @@ The `login` method is appropriate for testing APIs that use session authenticati
The `credentials` method can be used to set headers that will then be included on all subsequent requests by the test client. The `credentials` method can be used to set headers that will then be included on all subsequent requests by the test client.
from rest_framework.authtoken.models import Token
from rest_framework.test import APIClient
# Include an appropriate `Authorization:` header on all requests. # Include an appropriate `Authorization:` header on all requests.
token = Token.objects.get(username='lauren') token = Token.objects.get(user__username='lauren')
client = APIClient() client = APIClient()
client.credentials(HTTP_AUTHORIZATION='Token ' + token.key) client.credentials(HTTP_AUTHORIZATION='Token ' + token.key)
...@@ -190,10 +205,10 @@ You can use any of REST framework's test case classes as you would for the regul ...@@ -190,10 +205,10 @@ You can use any of REST framework's test case classes as you would for the regul
Ensure we can create a new account object. Ensure we can create a new account object.
""" """
url = reverse('account-list') url = reverse('account-list')
data = {'name': 'DabApps'} expected = {'name': 'DabApps'}
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, data) self.assertEqual(response.data, expected)
--- ---
......
...@@ -43,6 +43,10 @@ The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `mi ...@@ -43,6 +43,10 @@ The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `mi
You can also set the throttling policy on a per-view or per-viewset basis, You can also set the throttling policy on a per-view or per-viewset basis,
using the `APIView` class based views. using the `APIView` class based views.
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
class ExampleView(APIView): class ExampleView(APIView):
throttle_classes = (UserRateThrottle,) throttle_classes = (UserRateThrottle,)
...@@ -66,6 +70,13 @@ Or, if you're using the `@api_view` decorator with function based views. ...@@ -66,6 +70,13 @@ Or, if you're using the `@api_view` decorator with function based views.
The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details. The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details.
If you need to use a cache other than `'default'`, you can do so by creating a custom throttle class and setting the `cache` attribute. For example:
class CustomAnonRateThrottle(AnonRateThrottle):
cache = get_cache('alternate')
You'll need to rememeber to also set your custom throttle class in the `'DEFAULT_THROTTLE_CLASSES'` settings key, or using the `throttle_classes` view attribute.
--- ---
# API Reference # API Reference
......
...@@ -19,6 +19,12 @@ Typically, rather than explicitly registering the views in a viewset in the urlc ...@@ -19,6 +19,12 @@ Typically, rather than explicitly registering the views in a viewset in the urlc
Let's define a simple viewset that can be used to list or retrieve all the users in the system. Let's define a simple viewset that can be used to list or retrieve all the users in the system.
from django.contrib.auth.models import User
from django.shortcuts import get_object_or_404
from myapps.serializers import UserSerializer
from rest_framework import viewsets
from rest_framewor.responses import Response
class UserViewSet(viewsets.ViewSet): class UserViewSet(viewsets.ViewSet):
""" """
A simple ViewSet that for listing or retrieving users. A simple ViewSet that for listing or retrieving users.
...@@ -41,6 +47,9 @@ If we need to, we can bind this viewset into two separate views, like so: ...@@ -41,6 +47,9 @@ If we need to, we can bind this viewset into two separate views, like so:
Typically we wouldn't do this, but would instead register the viewset with a router, and allow the urlconf to be automatically generated. Typically we wouldn't do this, but would instead register the viewset with a router, and allow the urlconf to be automatically generated.
from myapp.views import UserViewSet
from rest_framework.routers import DefaultRouter
router = DefaultRouter() router = DefaultRouter()
router.register(r'users', UserViewSet) router.register(r'users', UserViewSet)
urlpatterns = router.urls urlpatterns = router.urls
...@@ -133,6 +142,10 @@ The `@action` decorator will route `POST` requests by default, but may also acce ...@@ -133,6 +142,10 @@ The `@action` decorator will route `POST` requests by default, but may also acce
@action(methods=['POST', 'DELETE']) @action(methods=['POST', 'DELETE'])
def unset_password(self, request, pk=None): def unset_password(self, request, pk=None):
... ...
The two new actions will then be available at the urls `^users/{pk}/set_password/$` and `^users/{pk}/unset_password/$`
--- ---
# API Reference # API Reference
......
...@@ -200,7 +200,7 @@ To run the tests against all supported configurations, first install [the tox te ...@@ -200,7 +200,7 @@ To run the tests against all supported configurations, first install [the tox te
## Support ## Support
For support please see the [REST framework discussion group][group], try the `#restframework` channel on `irc.freenode.net`, or raise a question on [Stack Overflow][stack-overflow], making sure to include the ['django-rest-framework'][django-rest-framework-tag] tag. For support please see the [REST framework discussion group][group], try the `#restframework` channel on `irc.freenode.net`, search [the IRC archives][botbot], or raise a question on [Stack Overflow][stack-overflow], making sure to include the ['django-rest-framework'][django-rest-framework-tag] tag.
[Paid support is available][paid-support] from [DabApps][dabapps], and can include work on REST framework core, or support with building your REST framework API. Please [contact DabApps][contact-dabapps] if you'd like to discuss commercial support options. [Paid support is available][paid-support] from [DabApps][dabapps], and can include work on REST framework core, or support with building your REST framework API. Please [contact DabApps][contact-dabapps] if you'd like to discuss commercial support options.
...@@ -307,6 +307,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -307,6 +307,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
[tox]: http://testrun.org/tox/latest/ [tox]: http://testrun.org/tox/latest/
[group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework [group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework
[botbot]: https://botbot.me/freenode/restframework/
[stack-overflow]: http://stackoverflow.com/ [stack-overflow]: http://stackoverflow.com/
[django-rest-framework-tag]: http://stackoverflow.com/questions/tagged/django-rest-framework [django-rest-framework-tag]: http://stackoverflow.com/questions/tagged/django-rest-framework
[django-tag]: http://stackoverflow.com/questions/tagged/django [django-tag]: http://stackoverflow.com/questions/tagged/django
......
...@@ -157,6 +157,16 @@ The following people have helped make REST framework great. ...@@ -157,6 +157,16 @@ The following people have helped make REST framework great.
* Dan Stephenson - [etos] * Dan Stephenson - [etos]
* Martin Clement - [martync] * Martin Clement - [martync]
* Jeremy Satterfield - [jsatt] * Jeremy Satterfield - [jsatt]
* Christopher Paolini - [chrispaolini]
* Filipe A Ximenes - [filipeximenes]
* Ramiro Morales - [ramiro]
* Krzysztof Jurewicz - [krzysiekj]
* Eric Buehl - [ericbuehl]
* Kristian Øllegaard - [kristianoellegaard]
* Alexander Akhmetov - [alexander-akhmetov]
* Andrey Antukh - [niwibe]
* Mathieu Pillard - [diox]
* Edmond Wong - [edmondwong]
Many thanks to everyone who's contributed to the project. Many thanks to everyone who's contributed to the project.
...@@ -350,3 +360,13 @@ You can also contact [@_tomchristie][twitter] directly on twitter. ...@@ -350,3 +360,13 @@ You can also contact [@_tomchristie][twitter] directly on twitter.
[etos]: https://github.com/etos [etos]: https://github.com/etos
[martync]: https://github.com/martync [martync]: https://github.com/martync
[jsatt]: https://github.com/jsatt [jsatt]: https://github.com/jsatt
[chrispaolini]: https://github.com/chrispaolini
[filipeximenes]: https://github.com/filipeximenes
[ramiro]: https://github.com/ramiro
[krzysiekj]: https://github.com/krzysiekj
[ericbuehl]: https://github.com/ericbuehl
[kristianoellegaard]: https://github.com/kristianoellegaard
[alexander-akhmetov]: https://github.com/alexander-akhmetov
[niwibe]: https://github.com/niwibe
[diox]: https://github.com/diox
[edmondwong]: https://github.com/edmondwong
...@@ -40,6 +40,19 @@ You can determine your currently installed version using `pip freeze`: ...@@ -40,6 +40,19 @@ You can determine your currently installed version using `pip freeze`:
## 2.3.x series ## 2.3.x series
### Master
* Support customizable exception handling, using the `EXCEPTION_HANDLER` setting.
* Support customizable view name and description functions, using the `VIEW_NAME_FUNCTION` and `VIEW_DESCRIPTION_FUNCTION` settings.
* Added `MAX_PAGINATE_BY` setting and `max_paginate_by` generic view attribute.
* Added `cache` attribute to throttles to allow overriding of default cache.
* 'Raw data' tab in browsable API now contains pre-populated data.
* 'Raw data' and 'HTML form' tab preference in browseable API now saved between page views.
* Bugfix: `required=True` argument fixed for boolean serializer fields.
* Bugfix: `client.force_authenticate(None)` should also clear session info if it exists.
* Bugfix: Client sending emptry string instead of file now clears `FileField`.
* Bugfix: Empty values on ChoiceFields with `required=False` now consistently return `None`.
### 2.3.7 ### 2.3.7
**Date**: 16th August 2013 **Date**: 16th August 2013
......
...@@ -61,6 +61,7 @@ To see what's going on under the hood let's first explicitly create a set of vie ...@@ -61,6 +61,7 @@ To see what's going on under the hood let's first explicitly create a set of vie
In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views. In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views.
from snippets.views import SnippetViewSet, UserViewSet from snippets.views import SnippetViewSet, UserViewSet
from rest_framework import renderers
snippet_list = SnippetViewSet.as_view({ snippet_list = SnippetViewSet.as_view({
'get': 'list', 'get': 'list',
...@@ -101,6 +102,7 @@ Because we're using `ViewSet` classes rather than `View` classes, we actually do ...@@ -101,6 +102,7 @@ Because we're using `ViewSet` classes rather than `View` classes, we actually do
Here's our re-wired `urls.py` file. Here's our re-wired `urls.py` file.
from django.conf.urls import patterns, url, include
from snippets import views from snippets import views
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
......
...@@ -12,7 +12,7 @@ Create a new Django project named `tutorial`, then start a new app called `quick ...@@ -12,7 +12,7 @@ Create a new Django project named `tutorial`, then start a new app called `quick
# Create a virtualenv to isolate our package dependencies locally # Create a virtualenv to isolate our package dependencies locally
virtualenv env virtualenv env
source env/bin/activate source env/bin/activate # On Windows use `env\Scripts\activate`
# Install Django and Django REST framework into the virtualenv # Install Django and Django REST framework into the virtualenv
pip install django pip install django
......
...@@ -16,6 +16,7 @@ from django.core import validators ...@@ -16,6 +16,7 @@ from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import QueryDict
from django.forms import widgets from django.forms import widgets
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
...@@ -307,7 +308,10 @@ class WritableField(Field): ...@@ -307,7 +308,10 @@ class WritableField(Field):
try: try:
if self.use_files: if self.use_files:
files = files or {} files = files or {}
try:
native = files[field_name] native = files[field_name]
except KeyError:
native = data[field_name]
else: else:
native = data[field_name] native = data[field_name]
except KeyError: except KeyError:
...@@ -399,10 +403,15 @@ class BooleanField(WritableField): ...@@ -399,10 +403,15 @@ class BooleanField(WritableField):
} }
empty = False empty = False
# Note: we set default to `False` in order to fill in missing value not def field_from_native(self, data, files, field_name, into):
# supplied by html form. TODO: Fix so that only html form input gets # HTML checkboxes do not explicitly represent unchecked as `False`
# this behavior. # we deal with that here...
default = False if isinstance(data, QueryDict):
self.default = False
return super(BooleanField, self).field_from_native(
data, files, field_name, into
)
def from_native(self, value): def from_native(self, value):
if value in ('true', 't', 'True', '1'): if value in ('true', 't', 'True', '1'):
...@@ -505,6 +514,11 @@ class ChoiceField(WritableField): ...@@ -505,6 +514,11 @@ class ChoiceField(WritableField):
return True return True
return False return False
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
return super(ChoiceField, self).from_native(value)
class EmailField(CharField): class EmailField(CharField):
type_name = 'EmailField' type_name = 'EmailField'
......
...@@ -14,6 +14,17 @@ from rest_framework.settings import api_settings ...@@ -14,6 +14,17 @@ from rest_framework.settings import api_settings
import warnings import warnings
def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
if cutoff:
ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, **filter_kwargs): def get_object_or_404(queryset, **filter_kwargs):
""" """
Same as Django's standard shortcut, but make sure to raise 404 Same as Django's standard shortcut, but make sure to raise 404
...@@ -47,6 +58,7 @@ class GenericAPIView(views.APIView): ...@@ -47,6 +58,7 @@ class GenericAPIView(views.APIView):
# Pagination settings # Pagination settings
paginate_by = api_settings.PAGINATE_BY paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM paginate_by_param = api_settings.PAGINATE_BY_PARAM
max_paginate_by = api_settings.MAX_PAGINATE_BY
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
page_kwarg = 'page' page_kwarg = 'page'
...@@ -135,7 +147,7 @@ class GenericAPIView(views.APIView): ...@@ -135,7 +147,7 @@ class GenericAPIView(views.APIView):
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1 page = page_kwarg or page_query_param or 1
try: try:
page_number = int(page) page_number = strict_positive_int(page)
except ValueError: except ValueError:
if page == 'last': if page == 'last':
page_number = paginator.num_pages page_number = paginator.num_pages
...@@ -196,9 +208,11 @@ class GenericAPIView(views.APIView): ...@@ -196,9 +208,11 @@ class GenericAPIView(views.APIView):
PendingDeprecationWarning, stacklevel=2) PendingDeprecationWarning, stacklevel=2)
if self.paginate_by_param: if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
try: try:
return int(query_params[self.paginate_by_param]) return strict_positive_int(
self.request.QUERY_PARAMS[self.paginate_by_param],
cutoff=self.max_paginate_by
)
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
...@@ -342,8 +356,15 @@ class GenericAPIView(views.APIView): ...@@ -342,8 +356,15 @@ class GenericAPIView(views.APIView):
self.check_permissions(cloned_request) self.check_permissions(cloned_request)
# Test object permissions # Test object permissions
if method == 'PUT': if method == 'PUT':
try:
self.get_object() self.get_object()
except (exceptions.APIException, PermissionDenied, Http404): except Http404:
# Http404 should be acceptable and the serializer
# metadata should be populated. Except this so the
# outer "else" clause of the try-except-else block
# will be executed.
pass
except (exceptions.APIException, PermissionDenied):
pass pass
else: else:
# If user has appropriate permissions for the view, include # If user has appropriate permissions for the view, include
......
...@@ -142,11 +142,16 @@ class UpdateModelMixin(object): ...@@ -142,11 +142,16 @@ class UpdateModelMixin(object):
try: try:
return self.get_object() return self.get_object()
except Http404: except Http404:
# If this is a PUT-as-create operation, we need to ensure that if self.request.method == 'PUT':
# we have relevant permissions, as if this was a POST request. # For PUT-as-create operation, we need to ensure that we have
# This will either raise a PermissionDenied exception, # relevant permissions, as if this was a POST request. This
# or simply return None # will either raise a PermissionDenied exception, or simply
# return None.
self.check_permissions(clone_request(self.request, 'POST')) self.check_permissions(clone_request(self.request, 'POST'))
else:
# PATCH requests where the object does not exist should still
# return a 404 response.
raise
def pre_save(self, obj): def pre_save(self, obj):
""" """
......
...@@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers ...@@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
from rest_framework.compat import yaml, etree from rest_framework.compat import etree, six, yaml
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.compat import six from rest_framework import renderers
import json import json
import datetime import datetime
import decimal import decimal
...@@ -47,6 +47,7 @@ class JSONParser(BaseParser): ...@@ -47,6 +47,7 @@ class JSONParser(BaseParser):
""" """
media_type = 'application/json' media_type = 'application/json'
renderer_class = renderers.UnicodeJSONRenderer
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
""" """
...@@ -121,7 +122,8 @@ class MultiPartParser(BaseParser): ...@@ -121,7 +122,8 @@ class MultiPartParser(BaseParser):
parser_context = parser_context or {} parser_context = parser_context or {}
request = parser_context['request'] request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers upload_handlers = request.upload_handlers
try: try:
...@@ -129,7 +131,7 @@ class MultiPartParser(BaseParser): ...@@ -129,7 +131,7 @@ class MultiPartParser(BaseParser):
data, files = parser.parse() data, files = parser.parse()
return DataAndFiles(data, files) return DataAndFiles(data, files)
except MultiPartParserError as exc: except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % six.u(exc)) raise ParseError('Multipart form parse error - %s' % str(exc))
class XMLParser(BaseParser): class XMLParser(BaseParser):
......
...@@ -134,9 +134,9 @@ class RelatedField(WritableField): ...@@ -134,9 +134,9 @@ class RelatedField(WritableField):
value = obj value = obj
for component in source.split('.'): for component in source.split('.'):
value = get_component(value, component)
if value is None: if value is None:
break break
value = get_component(value, component)
except ObjectDoesNotExist: except ObjectDoesNotExist:
return None return None
...@@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField): ...@@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField):
source = self.source or field_name source = self.source or field_name
queryset = obj queryset = obj
for component in source.split('.'): for component in source.split('.'):
if queryset is None:
return []
queryset = get_component(queryset, component) queryset = get_component(queryset, component)
# Forward relationship # Forward relationship
...@@ -262,7 +264,7 @@ class PrimaryKeyRelatedField(RelatedField): ...@@ -262,7 +264,7 @@ class PrimaryKeyRelatedField(RelatedField):
# RelatedObject (reverse relationship) # RelatedObject (reverse relationship)
try: try:
pk = getattr(obj, self.source or field_name).pk pk = getattr(obj, self.source or field_name).pk
except ObjectDoesNotExist: except (ObjectDoesNotExist, AttributeError):
return None return None
# Forward relationship # Forward relationship
...@@ -567,8 +569,13 @@ class HyperlinkedIdentityField(Field): ...@@ -567,8 +569,13 @@ class HyperlinkedIdentityField(Field):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field` May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf. attributes are not configured to correctly match the URL conf.
""" """
lookup_field = getattr(obj, self.lookup_field) lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field} kwargs = {self.lookup_field: lookup_field}
# Handle unsaved object case
if lookup_field is None:
return None
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch: except NoReverseMatch:
......
...@@ -28,6 +28,29 @@ def is_form_media_type(media_type): ...@@ -28,6 +28,29 @@ def is_form_media_type(media_type):
base_media_type == 'multipart/form-data') base_media_type == 'multipart/form-data')
class override_method(object):
"""
A context manager that temporarily overrides the method on a request,
additionally setting the `view.request` attribute.
Usage:
with override_method(view, request, 'POST') as request:
... # Do stuff with `view` and `request`
"""
def __init__(self, view, request, method):
self.view = view
self.request = request
self.method = method
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
return self.view.request
def __exit__(self, *args, **kwarg):
self.view.request = self.request
class Empty(object): class Empty(object):
""" """
Placeholder for unset attributes. Placeholder for unset attributes.
......
...@@ -189,7 +189,11 @@ class SimpleRouter(BaseRouter): ...@@ -189,7 +189,11 @@ class SimpleRouter(BaseRouter):
Given a viewset, return the portion of URL regex that is used Given a viewset, return the portion of URL regex that is used
to match against a single instance. to match against a single instance.
""" """
if self.trailing_slash:
base_regex = '(?P<{lookup_field}>[^/]+)' base_regex = '(?P<{lookup_field}>[^/]+)'
else:
# Don't consume `.json` style suffixes
base_regex = '(?P<{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk') lookup_field = getattr(viewset, 'lookup_field', 'pk')
return base_regex.format(lookup_field=lookup_field) return base_regex.format(lookup_field=lookup_field)
......
...@@ -32,6 +32,9 @@ from rest_framework.relations import * ...@@ -32,6 +32,9 @@ from rest_framework.relations import *
from rest_framework.fields import * from rest_framework.fields import *
class RelationsList(list):
_deleted = []
class NestedValidationError(ValidationError): class NestedValidationError(ValidationError):
""" """
The default ValidationError behavior is to stringify each item in the list The default ValidationError behavior is to stringify each item in the list
...@@ -161,7 +164,6 @@ class BaseSerializer(WritableField): ...@@ -161,7 +164,6 @@ class BaseSerializer(WritableField):
self._data = None self._data = None
self._files = None self._files = None
self._errors = None self._errors = None
self._deleted = None
if many and instance is not None and not hasattr(instance, '__iter__'): if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True') raise ValueError('instance should be a queryset or other iterable with many=True')
...@@ -298,7 +300,8 @@ class BaseSerializer(WritableField): ...@@ -298,7 +300,8 @@ class BaseSerializer(WritableField):
Serialize objects -> primitives. Serialize objects -> primitives.
""" """
ret = self._dict_class() ret = self._dict_class()
ret.fields = {} ret.fields = self._dict_class()
ret.empty = obj is None
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name) field.initialize(parent=self, field_name=field_name)
...@@ -331,14 +334,15 @@ class BaseSerializer(WritableField): ...@@ -331,14 +334,15 @@ class BaseSerializer(WritableField):
if self.source == '*': if self.source == '*':
return self.to_native(obj) return self.to_native(obj)
# Get the raw field value
try: try:
source = self.source or field_name source = self.source or field_name
value = obj value = obj
for component in source.split('.'): for component in source.split('.'):
value = get_component(value, component)
if value is None: if value is None:
break break
value = get_component(value, component)
except ObjectDoesNotExist: except ObjectDoesNotExist:
return None return None
...@@ -378,6 +382,7 @@ class BaseSerializer(WritableField): ...@@ -378,6 +382,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists # Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*': if self.source == '*':
if value: if value:
...@@ -391,7 +396,8 @@ class BaseSerializer(WritableField): ...@@ -391,7 +396,8 @@ class BaseSerializer(WritableField):
'data': value, 'data': value,
'context': self.context, 'context': self.context,
'partial': self.partial, 'partial': self.partial,
'many': self.many 'many': self.many,
'allow_add_remove': self.allow_add_remove
} }
serializer = self.__class__(**kwargs) serializer = self.__class__(**kwargs)
...@@ -434,7 +440,7 @@ class BaseSerializer(WritableField): ...@@ -434,7 +440,7 @@ class BaseSerializer(WritableField):
DeprecationWarning, stacklevel=3) DeprecationWarning, stacklevel=3)
if many: if many:
ret = [] ret = RelationsList()
errors = [] errors = []
update = self.object is not None update = self.object is not None
...@@ -461,8 +467,8 @@ class BaseSerializer(WritableField): ...@@ -461,8 +467,8 @@ class BaseSerializer(WritableField):
ret.append(self.from_native(item, None)) ret.append(self.from_native(item, None))
errors.append(self._errors) errors.append(self._errors)
if update: if update and self.allow_add_remove:
self._deleted = identity_to_objects.values() ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or [] self._errors = any(errors) and errors or []
else: else:
...@@ -514,12 +520,12 @@ class BaseSerializer(WritableField): ...@@ -514,12 +520,12 @@ class BaseSerializer(WritableField):
""" """
if isinstance(self.object, list): if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object] [self.save_object(item, **kwargs) for item in self.object]
if self.object._deleted:
[self.delete_object(item) for item in self.object._deleted]
else: else:
self.save_object(self.object, **kwargs) self.save_object(self.object, **kwargs)
if self.allow_add_remove and self._deleted:
[self.delete_object(item) for item in self._deleted]
return self.object return self.object
def metadata(self): def metadata(self):
...@@ -795,9 +801,12 @@ class ModelSerializer(Serializer): ...@@ -795,9 +801,12 @@ class ModelSerializer(Serializer):
cls = self.opts.model cls = self.opts.model
opts = get_concrete_model(cls)._meta opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many] exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field_name = field.source or field_name field_name = field.source or field_name
if field_name in exclusions and not field.read_only: if field_name in exclusions \
and not field.read_only \
and not isinstance(field, Serializer):
exclusions.remove(field_name) exclusions.remove(field_name)
return exclusions return exclusions
...@@ -823,6 +832,7 @@ class ModelSerializer(Serializer): ...@@ -823,6 +832,7 @@ class ModelSerializer(Serializer):
""" """
m2m_data = {} m2m_data = {}
related_data = {} related_data = {}
nested_forward_relations = {}
meta = self.opts.model._meta meta = self.opts.model._meta
# Reverse fk or one-to-one relations # Reverse fk or one-to-one relations
...@@ -842,6 +852,12 @@ class ModelSerializer(Serializer): ...@@ -842,6 +852,12 @@ class ModelSerializer(Serializer):
if field.name in attrs: if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name) m2m_data[field.name] = attrs.pop(field.name)
# Nested forward relations - These need to be marked so we can save
# them before saving the parent model instance.
for field_name in attrs.keys():
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
# Update an existing instance... # Update an existing instance...
if instance is not None: if instance is not None:
for key, val in attrs.items(): for key, val in attrs.items():
...@@ -857,6 +873,7 @@ class ModelSerializer(Serializer): ...@@ -857,6 +873,7 @@ class ModelSerializer(Serializer):
# at the point of save. # at the point of save.
instance._related_data = related_data instance._related_data = related_data
instance._m2m_data = m2m_data instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations
return instance return instance
...@@ -872,6 +889,14 @@ class ModelSerializer(Serializer): ...@@ -872,6 +889,14 @@ class ModelSerializer(Serializer):
""" """
Save the deserialized object and return it. Save the deserialized object and return it.
""" """
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the
# parent instance.
for field_name, sub_object in obj._nested_forward_relations.items():
if sub_object:
self.save_object(sub_object)
setattr(obj, field_name, sub_object)
obj.save(**kwargs) obj.save(**kwargs)
if getattr(obj, '_m2m_data', None): if getattr(obj, '_m2m_data', None):
...@@ -881,6 +906,24 @@ class ModelSerializer(Serializer): ...@@ -881,6 +906,24 @@ class ModelSerializer(Serializer):
if getattr(obj, '_related_data', None): if getattr(obj, '_related_data', None):
for accessor_name, related in obj._related_data.items(): for accessor_name, related in obj._related_data.items():
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)
# Delete any removed objects
if related._deleted:
[self.delete_object(item) for item in related._deleted]
elif isinstance(related, models.Model):
# Nested reverse one-one relationship
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related, fk_field, obj)
self.save_object(related)
else:
# Reverse FK or reverse one-one
setattr(obj, accessor_name, related) setattr(obj, accessor_name, related)
del(obj._related_data) del(obj._related_data)
...@@ -903,6 +946,7 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -903,6 +946,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail' _default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField _hyperlink_field_class = HyperlinkedRelatedField
_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self): def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields() fields = super(HyperlinkedModelSerializer, self).get_default_fields()
...@@ -911,7 +955,7 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -911,7 +955,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
self.opts.view_name = self._get_default_view_name(self.opts.model) self.opts.view_name = self._get_default_view_name(self.opts.model)
if 'url' not in fields: if 'url' not in fields:
url_field = HyperlinkedIdentityField( url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name, view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field lookup_field=self.opts.lookup_field
) )
......
...@@ -48,7 +48,6 @@ DEFAULTS = { ...@@ -48,7 +48,6 @@ DEFAULTS = {
), ),
'DEFAULT_THROTTLE_CLASSES': ( 'DEFAULT_THROTTLE_CLASSES': (
), ),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation', 'rest_framework.negotiation.DefaultContentNegotiation',
...@@ -68,11 +67,19 @@ DEFAULTS = { ...@@ -68,11 +67,19 @@ DEFAULTS = {
# Pagination # Pagination
'PAGINATE_BY': None, 'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None, 'PAGINATE_BY_PARAM': None,
'MAX_PAGINATE_BY': None,
# Authentication # Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None, 'UNAUTHENTICATED_TOKEN': None,
# View configuration
'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
# Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
# Testing # Testing
'TEST_REQUEST_RENDERER_CLASSES': ( 'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer', 'rest_framework.renderers.MultiPartRenderer',
...@@ -121,10 +128,13 @@ IMPORT_STRINGS = ( ...@@ -121,10 +128,13 @@ IMPORT_STRINGS = (
'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS', 'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
'FILTER_BACKEND', 'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES', 'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN', 'UNAUTHENTICATED_TOKEN',
'VIEW_NAME_FUNCTION',
'VIEW_DESCRIPTION_FUNCTION'
) )
......
function getCookie(c_name)
{
// From http://www.w3schools.com/js/js_cookies.asp
var c_value = document.cookie;
var c_start = c_value.indexOf(" " + c_name + "=");
if (c_start == -1) {
c_start = c_value.indexOf(c_name + "=");
}
if (c_start == -1) {
c_value = null;
} else {
c_start = c_value.indexOf("=", c_start) + 1;
var c_end = c_value.indexOf(";", c_start);
if (c_end == -1) {
c_end = c_value.length;
}
c_value = unescape(c_value.substring(c_start,c_end));
}
return c_value;
}
// JSON highlighting.
prettyPrint(); prettyPrint();
// Bootstrap tooltips.
$('.js-tooltip').tooltip({ $('.js-tooltip').tooltip({
delay: 1000 delay: 1000
}); });
// Deal with rounded tab styling after tab clicks.
$('a[data-toggle="tab"]:first').on('shown', function (e) { $('a[data-toggle="tab"]:first').on('shown', function (e) {
$(e.target).parents('.tabbable').addClass('first-tab-active'); $(e.target).parents('.tabbable').addClass('first-tab-active');
}); });
$('a[data-toggle="tab"]:not(:first)').on('shown', function (e) { $('a[data-toggle="tab"]:not(:first)').on('shown', function (e) {
$(e.target).parents('.tabbable').removeClass('first-tab-active'); $(e.target).parents('.tabbable').removeClass('first-tab-active');
}); });
$('.form-switcher a:first').tab('show');
$('a[data-toggle="tab"]').click(function(){
document.cookie="tabstyle=" + this.name + "; path=/";
});
// Store tab preference in cookies & display appropriate tab on load.
var selectedTab = null;
var selectedTabName = getCookie('tabstyle');
if (selectedTabName) {
selectedTab = $('.form-switcher a[name=' + selectedTabName + ']');
}
if (selectedTab && selectedTab.length > 0) {
// Display whichever tab is selected.
selectedTab.tab('show');
} else {
// If no tab selected, display rightmost tab.
$('.form-switcher a:first').tab('show');
}
...@@ -128,17 +128,17 @@ ...@@ -128,17 +128,17 @@
<div {% if post_form %}class="tabbable"{% endif %}> <div {% if post_form %}class="tabbable"{% endif %}>
{% if post_form %} {% if post_form %}
<ul class="nav nav-tabs form-switcher"> <ul class="nav nav-tabs form-switcher">
<li><a href="#object-form" data-toggle="tab">HTML form</a></li> <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
<li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul> </ul>
{% endif %} {% endif %}
<div class="well tab-content"> <div class="well tab-content">
{% if post_form %} {% if post_form %}
<div class="tab-pane" id="object-form"> <div class="tab-pane" id="object-form">
{% with form=post_form %} {% with form=post_form %}
<form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset> <fieldset>
{% include "rest_framework/form.html" %} {{ post_form }}
<div class="form-actions"> <div class="form-actions">
<button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button> <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div> </div>
...@@ -167,23 +167,21 @@ ...@@ -167,23 +167,21 @@
<div {% if put_form %}class="tabbable"{% endif %}> <div {% if put_form %}class="tabbable"{% endif %}>
{% if put_form %} {% if put_form %}
<ul class="nav nav-tabs form-switcher"> <ul class="nav nav-tabs form-switcher">
<li><a href="#object-form" data-toggle="tab">HTML form</a></li> <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
<li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul> </ul>
{% endif %} {% endif %}
<div class="well tab-content"> <div class="well tab-content">
{% if put_form %} {% if put_form %}
<div class="tab-pane" id="object-form"> <div class="tab-pane" id="object-form">
{% with form=put_form %} <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset> <fieldset>
{% include "rest_framework/form.html" %} {{ put_form }}
<div class="form-actions"> <div class="form-actions">
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
</div> </div>
</fieldset> </fieldset>
</form> </form>
{% endwith %}
</div> </div>
{% endif %} {% endif %}
<div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form"> <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form">
......
...@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient): ...@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):
""" """
self.handler._force_user = user self.handler._force_user = user
self.handler._force_token = token self.handler._force_token = token
if user is None:
self.logout() # Also clear any possible session info if required
def request(self, **kwargs): def request(self, **kwargs):
# Ensure that any credentials set get added to every request. # Ensure that any credentials set get added to every request.
......
...@@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text ...@@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
from rest_framework.tests.description import UTF8_TEST_DOCSTRING from rest_framework.tests.description import UTF8_TEST_DOCSTRING
from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented. # We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring DESCRIPTION = """an example docstring
...@@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase): ...@@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
class MockView(APIView): class MockView(APIView):
pass pass
self.assertEqual(get_view_name(MockView), 'Mock') self.assertEqual(MockView().get_view_name(), 'Mock')
def test_view_description_uses_docstring(self): def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring.""" """Ensure view descriptions are based on the docstring."""
...@@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase): ...@@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #""" # hash style header #"""
self.assertEqual(get_view_description(MockView), DESCRIPTION) self.assertEqual(MockView().get_view_description(), DESCRIPTION)
def test_view_description_supports_unicode(self): def test_view_description_supports_unicode(self):
""" """
...@@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase): ...@@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
self.assertEqual( self.assertEqual(
get_view_description(ViewWithNonASCIICharactersInDocstring), ViewWithNonASCIICharactersInDocstring().get_view_description(),
smart_text(UTF8_TEST_DOCSTRING) smart_text(UTF8_TEST_DOCSTRING)
) )
...@@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase): ...@@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
class MockView(APIView): class MockView(APIView):
pass pass
self.assertEqual(get_view_description(MockView), '') self.assertEqual(MockView().get_view_description(), '')
def test_markdown(self): def test_markdown(self):
""" """
......
...@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase): ...@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
def test_from_native_empty(self):
"""
Make sure from_native() returns None on empty param.
"""
f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
result = f.from_native('')
self.assertEqual(result, None)
class EmailFieldTests(TestCase): class EmailFieldTests(TestCase):
""" """
...@@ -896,3 +904,12 @@ class CustomIntegerField(TestCase): ...@@ -896,3 +904,12 @@ class CustomIntegerField(TestCase):
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
class BooleanField(TestCase):
"""
Tests for BooleanField
"""
def test_boolean_required(self):
class BooleanRequiredSerializer(serializers.Serializer):
bool_field = serializers.BooleanField(required=True)
self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
...@@ -7,13 +7,13 @@ import datetime ...@@ -7,13 +7,13 @@ import datetime
class UploadedFile(object): class UploadedFile(object):
def __init__(self, file, created=None): def __init__(self, file=None, created=None):
self.file = file self.file = file
self.created = created or datetime.datetime.now() self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer): class UploadedFileSerializer(serializers.Serializer):
file = serializers.FileField() file = serializers.FileField(required=False)
created = serializers.DateTimeField() created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
...@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase): ...@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):
now = datetime.datetime.now() now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now}) serializer = UploadedFileSerializer(data={'created': now})
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, now)
self.assertIsNone(serializer.object.file)
def test_remove_with_empty_string(self):
"""
Passing empty string as data should cause file to be removed
Test for:
https://github.com/tomchristie/django-rest-framework/issues/937
"""
now = datetime.datetime.now()
file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt'
file.size = len(file.getvalue())
uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertIsNone(serializer.object.file)
def test_validation_error_with_non_file(self):
"""
Passing non-files should raise a validation error.
"""
now = datetime.datetime.now()
errmsg = 'No file was submitted. Check the encoding type on the form.'
serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertIn('file', serializer.errors) self.assertEqual(serializer.errors, {'file': [errmsg]})
...@@ -272,6 +272,48 @@ class TestInstanceView(TestCase): ...@@ -272,6 +272,48 @@ class TestInstanceView(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) self.assertEqual(response.data, expected)
def test_options_before_instance_create(self):
"""
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
before the instance has been created
"""
request = factory.options('/999')
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
expected = {
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'renders': [
'application/json',
'text/html'
],
'name': 'Instance',
'description': 'Example description for OPTIONS.',
'actions': {
'PUT': {
'text': {
'max_length': 100,
'read_only': False,
'required': True,
'type': 'string',
'label': 'Text comes here',
'help_text': 'Text description.'
},
'id': {
'read_only': True,
'required': False,
'type': 'integer',
'label': 'ID',
},
}
}
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
def test_get_instance_view_incorrect_arg(self): def test_get_instance_view_incorrect_arg(self):
""" """
GET requests with an incorrect pk type, should raise 404, not 500. GET requests with an incorrect pk type, should raise 404, not 500.
...@@ -338,6 +380,17 @@ class TestInstanceView(TestCase): ...@@ -338,6 +380,17 @@ class TestInstanceView(TestCase):
new_obj = SlugBasedModel.objects.get(slug='test_slug') new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar') self.assertEqual(new_obj.text, 'foobar')
def test_patch_cannot_create_an_object(self):
"""
PATCH requests should not be able to create objects.
"""
data = {'text': 'foobar'}
request = factory.patch('/999', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertFalse(self.objects.filter(id=999).exists())
class TestOverriddenGetObject(TestCase): class TestOverriddenGetObject(TestCase):
""" """
......
...@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView): ...@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):
paginate_by_param = 'page_size' paginate_by_param = 'page_size'
class MaxPaginateByView(generics.ListAPIView):
"""
View for testing custom max_paginate_by usage
"""
model = BasicModel
paginate_by = 3
max_paginate_by = 5
paginate_by_param = 'page_size'
class IntegrationTestPagination(TestCase): class IntegrationTestPagination(TestCase):
""" """
Integration tests for paginated list views. Integration tests for paginated list views.
...@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase): ...@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])
class TestMaxPaginateByParam(TestCase):
"""
Tests for list views with max_paginate_by kwarg
"""
def setUp(self):
"""
Create 13 BasicModel instances.
"""
for i in range(13):
BasicModel(text=i).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = MaxPaginateByView.as_view()
def test_max_paginate_by(self):
"""
If max_paginate_by is set, it should limit page size for the view.
"""
request = factory.get('/?page_size=10')
response = self.view(request).render()
self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5])
def test_max_paginate_by_without_page_size_param(self):
"""
If max_paginate_by is set, but client does not specifiy page_size,
standard `paginate_by` behavior should be used.
"""
request = factory.get('/')
response = self.view(request).render()
self.assertEqual(response.data['results'], self.data[:3])
### Tests for context in pagination serializers ### Tests for context in pagination serializers
class CustomField(serializers.Field): class CustomField(serializers.Field):
......
...@@ -283,6 +283,15 @@ class PKForeignKeyTests(TestCase): ...@@ -283,6 +283,15 @@ class PKForeignKeyTests(TestCase):
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']}) self.assertEqual(serializer.errors, {'target': ['This field is required.']})
def test_foreign_key_with_empty(self):
"""
Regression test for #1072
https://github.com/tomchristie/django-rest-framework/issues/1072
"""
serializer = NullableForeignKeySourceSerializer()
self.assertEqual(serializer.data['target'], None)
class PKNullableForeignKeyTests(TestCase): class PKNullableForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
......
...@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase): ...@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):
self.urls = self.router.urls self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self): def test_urls_can_have_trailing_slash_removed(self):
expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)): for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern) self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
......
...@@ -17,8 +17,18 @@ def view(request): ...@@ -17,8 +17,18 @@ def view(request):
}) })
@api_view(['GET', 'POST'])
def session_view(request):
active_session = request.session.get('active_session', False)
request.session['active_session'] = True
return Response({
'active_session': active_session
})
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^view/$', view), url(r'^view/$', view),
url(r'^session-view/$', session_view),
) )
...@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): ...@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/') response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example') self.assertEqual(response.data['user'], 'example')
def test_force_authenticate_with_sessions(self):
"""
Setting `.force_authenticate()` forcibly authenticates each request.
"""
user = User.objects.create_user('example', 'example@example.com')
self.client.force_authenticate(user)
# First request does not yet have an active session
response = self.client.get('/session-view/')
self.assertEqual(response.data['active_session'], False)
# Subsequant requests have an active session
response = self.client.get('/session-view/')
self.assertEqual(response.data['active_session'], True)
# Force authenticating as `None` should also logout the user session.
self.client.force_authenticate(None)
response = self.client.get('/session-view/')
self.assertEqual(response.data['active_session'], False)
def test_csrf_exempt_by_default(self): def test_csrf_exempt_by_default(self):
""" """
By default, the test client is CSRF exempt. By default, the test client is CSRF exempt.
......
...@@ -32,6 +32,16 @@ def basic_view(request): ...@@ -32,6 +32,16 @@ def basic_view(request):
return {'method': 'PATCH', 'data': request.DATA} return {'method': 'PATCH', 'data': request.DATA}
class ErrorView(APIView):
def get(self, request, *args, **kwargs):
raise Exception
@api_view(['GET'])
def error_view(request):
raise Exception
def sanitise_json_error(error_dict): def sanitise_json_error(error_dict):
""" """
Exact contents of JSON error messages depend on the installed version Exact contents of JSON error messages depend on the installed version
...@@ -99,3 +109,34 @@ class FunctionBasedViewIntegrationTests(TestCase): ...@@ -99,3 +109,34 @@ class FunctionBasedViewIntegrationTests(TestCase):
} }
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected) self.assertEqual(sanitise_json_error(response.data), expected)
class TestCustomExceptionHandler(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc):
return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
request = factory.get('/', content_type='application/json')
response = view(request)
expected = 'Error!'
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, expected)
def test_function_based_view_exception_handler(self):
view = error_view
request = factory.get('/', content_type='application/json')
response = view(request)
expected = 'Error!'
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, expected)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Provides various throttling policies. Provides various throttling policies.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.cache import cache from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import time import time
...@@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle): ...@@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache. Previous request information used for throttling is stored in the cache.
""" """
cache = default_cache
timer = time.time timer = time.time
cache_format = 'throtte_%(scope)s_%(ident)s' cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None scope = None
...@@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle): ...@@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):
if self.key is None: if self.key is None:
return True return True
self.history = cache.get(self.key, []) self.history = self.cache.get(self.key, [])
self.now = self.timer() self.now = self.timer()
# Drop any requests from the history which have now passed the # Drop any requests from the history which have now passed the
...@@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle): ...@@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):
into the cache. into the cache.
""" """
self.history.insert(0, self.now) self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration) self.cache.set(self.key, self.history, self.duration)
return True return True
def throttle_failure(self): def throttle_failure(self):
...@@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle): ...@@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated(): if request.user.is_authenticated():
return None # Only throttle unauthenticated requests. return None # Only throttle unauthenticated requests.
ident = request.META.get('REMOTE_ADDR', None) ident = request.META.get('HTTP_X_FORWARDED_FOR')
if ident is None:
ident = request.META.get('REMOTE_ADDR')
return self.cache_format % { return self.cache_format % {
'scope': self.scope, 'scope': self.scope,
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix from django.core.urlresolvers import resolve, get_script_prefix
from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url): def get_breadcrumbs(url):
...@@ -9,8 +8,11 @@ def get_breadcrumbs(url): ...@@ -9,8 +8,11 @@ def get_breadcrumbs(url):
tuple of (name, url). tuple of (name, url).
""" """
from rest_framework.settings import api_settings
from rest_framework.views import APIView from rest_framework.views import APIView
view_name_func = api_settings.VIEW_NAME_FUNCTION
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
""" """
Add tuples of (name, url) to the breadcrumbs list, Add tuples of (name, url) to the breadcrumbs list,
...@@ -30,7 +32,7 @@ def get_breadcrumbs(url): ...@@ -30,7 +32,7 @@ def get_breadcrumbs(url):
# Probably an optional trailing slash. # Probably an optional trailing slash.
if not seen or seen[-1] != view: if not seen or seen[-1] != view:
suffix = getattr(view, 'suffix', None) suffix = getattr(view, 'suffix', None)
name = get_view_name(view.cls, suffix) name = view_name_func(cls, suffix)
breadcrumbs_list.insert(0, (name, prefix + url)) breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view) seen.append(view)
......
...@@ -5,11 +5,13 @@ from __future__ import unicode_literals ...@@ -5,11 +5,13 @@ from __future__ import unicode_literals
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown, smart_text from rest_framework.compat import apply_markdown
from rest_framework.settings import api_settings
from textwrap import dedent
import re import re
def _remove_trailing_string(content, trailing): def remove_trailing_string(content, trailing):
""" """
Strip trailing component `trailing` from `content` if it exists. Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view classes. Used when generating names from view classes.
...@@ -19,10 +21,14 @@ def _remove_trailing_string(content, trailing): ...@@ -19,10 +21,14 @@ def _remove_trailing_string(content, trailing):
return content return content
def _remove_leading_indent(content): def dedent(content):
""" """
Remove leading indent from a block of text. Remove leading indent from a block of text.
Used when generating descriptions from docstrings. Used when generating descriptions from docstrings.
Note that python's `textwrap.dedent` doesn't quite cut it,
as it fails to dedent multiline docstrings that include
unindented text on the initial line.
""" """
whitespace_counts = [len(line) - len(line.lstrip(' ')) whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()] for line in content.splitlines()[1:] if line.lstrip()]
...@@ -31,11 +37,10 @@ def _remove_leading_indent(content): ...@@ -31,11 +37,10 @@ def _remove_leading_indent(content):
if whitespace_counts: if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts)) whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
content = content.strip('\n')
return content
return content.strip()
def _camelcase_to_spaces(content): def camelcase_to_spaces(content):
""" """
Translate 'CamelCaseNames' to 'Camel Case Names'. Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes. Used when generating names from view classes.
...@@ -44,31 +49,6 @@ def _camelcase_to_spaces(content): ...@@ -44,31 +49,6 @@ def _camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip() content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title() return ' '.join(content.split('_')).title()
def get_view_name(cls, suffix=None):
"""
Return a formatted name for an `APIView` class or `@api_view` function.
"""
name = cls.__name__
name = _remove_trailing_string(name, 'View')
name = _remove_trailing_string(name, 'ViewSet')
name = _camelcase_to_spaces(name)
if suffix:
name += ' ' + suffix
return name
def get_view_description(cls, html=False):
"""
Return a description for an `APIView` class or `@api_view` function.
"""
description = cls.__doc__ or ''
description = _remove_leading_indent(smart_text(description))
if html:
return markup_description(description)
return description
def markup_description(description): def markup_description(description):
""" """
Apply HTML markup to the given description. Apply HTML markup to the given description.
......
...@@ -8,16 +8,79 @@ from django.http import Http404 ...@@ -8,16 +8,79 @@ from django.http import Http404
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions from rest_framework import status, exceptions
from rest_framework.compat import View, HttpResponseBase from rest_framework.compat import smart_text, HttpResponseBase, View
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework.utils import formatting
def get_view_name(view_cls, suffix=None):
"""
Given a view class, return a textual name to represent the view.
This name is used in the browsable API, and in OPTIONS responses.
This function is the default for the `VIEW_NAME_FUNCTION` setting.
"""
name = view_cls.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.camelcase_to_spaces(name)
if suffix:
name += ' ' + suffix
return name
def get_view_description(view_cls, html=False):
"""
Given a view class, return a textual description to represent the view.
This name is used in the browsable API, and in OPTIONS responses.
This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
"""
description = view_cls.__doc__ or ''
description = formatting.dedent(smart_text(description))
if html:
return formatting.markup_description(description)
return description
def exception_handler(exc):
"""
Returns the response that should be used for any given exception.
By default we handle the REST framework `APIException`, and also
Django's builtin `Http404` and `PermissionDenied` exceptions.
Any unhandled exceptions may return `None`, which will cause a 500 error
to be raised.
"""
if isinstance(exc, exceptions.APIException):
headers = {}
if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None):
headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
return Response({'detail': exc.detail},
status=exc.status_code,
headers=headers)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
status=status.HTTP_403_FORBIDDEN)
# Note: Unhandled exceptions will raise a 500 error.
return None
class APIView(View): class APIView(View):
settings = api_settings
# The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
...@@ -25,6 +88,9 @@ class APIView(View): ...@@ -25,6 +88,9 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
# Allow dependancy injection of other settings to make testing easier.
settings = api_settings
@classmethod @classmethod
def as_view(cls, **initkwargs): def as_view(cls, **initkwargs):
""" """
...@@ -110,6 +176,22 @@ class APIView(View): ...@@ -110,6 +176,22 @@ class APIView(View):
'request': getattr(self, 'request', None) 'request': getattr(self, 'request', None)
} }
def get_view_name(self):
"""
Return the view name, as used in OPTIONS responses and in the
browsable API.
"""
func = self.settings.VIEW_NAME_FUNCTION
return func(self.__class__, getattr(self, 'suffix', None))
def get_view_description(self, html=False):
"""
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
"""
func = self.settings.VIEW_DESCRIPTION_FUNCTION
return func(self.__class__, html)
# API policy instantiation methods # API policy instantiation methods
def get_format_suffix(self, **kwargs): def get_format_suffix(self, **kwargs):
...@@ -269,34 +351,24 @@ class APIView(View): ...@@ -269,34 +351,24 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response, Handle any exception that occurs, by returning an appropriate response,
or re-raising the error. or re-raising the error.
""" """
if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, (exceptions.NotAuthenticated, if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)): exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403 # WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request) auth_header = self.get_authenticate_header(self.request)
if auth_header: if auth_header:
self.headers['WWW-Authenticate'] = auth_header exc.auth_header = auth_header
else: else:
exc.status_code = status.HTTP_403_FORBIDDEN exc.status_code = status.HTTP_403_FORBIDDEN
if isinstance(exc, exceptions.APIException): response = self.settings.EXCEPTION_HANDLER(exc)
return Response({'detail': exc.detail},
status=exc.status_code, if response is None:
exception=True)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND,
exception=True)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
status=status.HTTP_403_FORBIDDEN,
exception=True)
raise raise
response.exception = True
return response
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
@csrf_exempt @csrf_exempt
...@@ -342,16 +414,12 @@ class APIView(View): ...@@ -342,16 +414,12 @@ class APIView(View):
Return a dictionary of metadata about the view. Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests. Used to return responses for OPTIONS requests.
""" """
# This is used by ViewSets to disambiguate instance vs list views
view_name_suffix = getattr(self, 'suffix', None)
# By default we can't provide any form-like information, however the # By default we can't provide any form-like information, however the
# generic views override this implementation and add additional # generic views override this implementation and add additional
# information for POST and PUT methods, based on the serializer. # information for POST and PUT methods, based on the serializer.
ret = SortedDict() ret = SortedDict()
ret['name'] = get_view_name(self.__class__, view_name_suffix) ret['name'] = self.get_view_name()
ret['description'] = get_view_description(self.__class__) ret['description'] = self.get_view_description()
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes]
return ret return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment