Commit d98245ac by Carlton Gibson

Merge branch '2.4.0' of github.com:tomchristie/django-rest-framework into #1559

Conflicts:
	docs/topics/release-notes.md
parents 3f727ce7 2489e38a
...@@ -5,17 +5,18 @@ python: ...@@ -5,17 +5,18 @@ python:
- "2.7" - "2.7"
- "3.2" - "3.2"
- "3.3" - "3.3"
- "3.4"
env: env:
- DJANGO="https://www.djangoproject.com/download/1.7b2/tarball/" - DJANGO="https://www.djangoproject.com/download/1.7.b4/tarball/"
- DJANGO="django==1.6.3" - DJANGO="django==1.6.5"
- DJANGO="django==1.5.6" - DJANGO="django==1.5.8"
- DJANGO="django==1.4.11" - DJANGO="django==1.4.13"
- DJANGO="django==1.3.7"
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install defusedxml==0.3 Pillow==2.3.0 - pip install defusedxml==0.3 Pillow==2.3.0
- pip install pytest-django==2.6.1
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
...@@ -23,22 +24,19 @@ install: ...@@ -23,22 +24,19 @@ install:
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi" - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi" - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7b2/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi" - "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7.b4/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
- python rest_framework/runtests/runtests.py - py.test
matrix: matrix:
exclude: exclude:
- python: "2.6" - python: "2.6"
env: DJANGO="https://www.djangoproject.com/download/1.7b2/tarball/" env: DJANGO="https://www.djangoproject.com/download/1.7.b4/tarball/"
- python: "3.2" - python: "3.2"
env: DJANGO="django==1.4.11" env: DJANGO="django==1.4.13"
- python: "3.2"
env: DJANGO="django==1.3.7"
- python: "3.3"
env: DJANGO="django==1.4.11"
- python: "3.3" - python: "3.3"
env: DJANGO="django==1.3.7" env: DJANGO="django==1.4.13"
- python: "3.4"
env: DJANGO="django==1.4.13"
...@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then: ...@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then:
pip install -r optionals.txt pip install -r optionals.txt
# Run the tests # Run the tests
rest_framework/runtests/runtests.py py.test
You can also use the excellent [`tox`][tox] testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run: You can also use the excellent [`tox`][tox] testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run:
......
def pytest_configure():
from django.conf import settings
settings.configure(
DEBUG_PROPAGATE_EXCEPTIONS=True,
DATABASES={'default': {'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:'}},
SECRET_KEY='not very secret in tests',
USE_I18N=True,
USE_L10N=True,
STATIC_URL='/static/',
ROOT_URLCONF='tests.urls',
TEMPLATE_LOADERS=(
'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader',
),
MIDDLEWARE_CLASSES=(
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
),
INSTALLED_APPS=(
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'tests',
'tests.accounts',
'tests.records',
'tests.users',
),
PASSWORD_HASHERS=(
'django.contrib.auth.hashers.SHA1PasswordHasher',
'django.contrib.auth.hashers.PBKDF2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
'django.contrib.auth.hashers.BCryptPasswordHasher',
'django.contrib.auth.hashers.MD5PasswordHasher',
'django.contrib.auth.hashers.CryptPasswordHasher',
),
)
try:
import oauth_provider
import oauth2
except ImportError:
pass
else:
settings.INSTALLED_APPS += (
'oauth_provider',
)
try:
import provider
except ImportError:
pass
else:
settings.INSTALLED_APPS += (
'provider',
'provider.oauth2',
)
# guardian is optional
try:
import guardian
except ImportError:
pass
else:
settings.ANONYMOUS_USER_ID = -1
settings.AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend', # default
'guardian.backends.ObjectPermissionBackend',
)
settings.INSTALLED_APPS += (
'guardian',
)
try:
import django
django.setup()
except AttributeError:
pass
...@@ -119,7 +119,7 @@ Unauthenticated responses that are denied permission will result in an `HTTP 401 ...@@ -119,7 +119,7 @@ Unauthenticated responses that are denied permission will result in an `HTTP 401
This authentication scheme uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients. This authentication scheme uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients.
To use the `TokenAuthentication` scheme, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting: To use the `TokenAuthentication` scheme you'll need to [configure the authentication classes](#setting-the-authentication-scheme) to include `TokenAuthentication`, and additionally include `rest_framework.authtoken` in your `INSTALLED_APPS` setting:
INSTALLED_APPS = ( INSTALLED_APPS = (
... ...
......
...@@ -164,11 +164,12 @@ Corresponds to `django.db.models.fields.BooleanField`. ...@@ -164,11 +164,12 @@ Corresponds to `django.db.models.fields.BooleanField`.
## CharField ## CharField
A text representation, optionally validates the text to be shorter than `max_length` and longer than `min_length`. A text representation, optionally validates the text to be shorter than `max_length` and longer than `min_length`.
If `allow_none` is `False` (default), `None` values will be converted to an empty string.
Corresponds to `django.db.models.fields.CharField` Corresponds to `django.db.models.fields.CharField`
or `django.db.models.fields.TextField`. or `django.db.models.fields.TextField`.
**Signature:** `CharField(max_length=None, min_length=None)` **Signature:** `CharField(max_length=None, min_length=None, allow_none=False)`
## URLField ## URLField
...@@ -184,7 +185,9 @@ Corresponds to `django.db.models.fields.SlugField`. ...@@ -184,7 +185,9 @@ Corresponds to `django.db.models.fields.SlugField`.
## ChoiceField ## ChoiceField
A field that can accept a value out of a limited set of choices. A field that can accept a value out of a limited set of choices. Optionally takes a `blank_display_value` parameter that customizes the display value of an empty choice.
**Signature:** `ChoiceField(choices=(), blank_display_value=None)`
## EmailField ## EmailField
......
...@@ -187,7 +187,7 @@ Remember that the `pre_save()` method is not called by `GenericAPIView` itself, ...@@ -187,7 +187,7 @@ Remember that the `pre_save()` method is not called by `GenericAPIView` itself,
You won't typically need to override the following methods, although you might need to call into them if you're writing custom views using `GenericAPIView`. You won't typically need to override the following methods, although you might need to call into them if you're writing custom views using `GenericAPIView`.
* `get_serializer_context(self)` - Returns a dictionary containing any extra context that should be supplied to the serializer. Defaults to including `'request'`, `'view'` and `'format'` keys. * `get_serializer_context(self)` - Returns a dictionary containing any extra context that should be supplied to the serializer. Defaults to including `'request'`, `'view'` and `'format'` keys.
* `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False)` - Returns a serializer instance. * `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False, allow_add_remove=False)` - Returns a serializer instance.
* `get_pagination_serializer(self, page)` - Returns a serializer instance to use with paginated data. * `get_pagination_serializer(self, page)` - Returns a serializer instance to use with paginated data.
* `paginate_queryset(self, queryset)` - Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view. * `paginate_queryset(self, queryset)` - Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view.
* `filter_queryset(self, queryset)` - Given a queryset, filter it with whichever filter backends are in use, returning a new queryset. * `filter_queryset(self, queryset)` - Given a queryset, filter it with whichever filter backends are in use, returning a new queryset.
......
...@@ -51,36 +51,41 @@ This means you'll need to explicitly set the `base_name` argument when registeri ...@@ -51,36 +51,41 @@ This means you'll need to explicitly set the `base_name` argument when registeri
### Extra link and actions ### Extra link and actions
Any methods on the viewset decorated with `@link` or `@action` will also be routed. Any methods on the viewset decorated with `@detail_route` or `@list_route` will also be routed.
For example, given a method like this on the `UserViewSet` class: For example, given a method like this on the `UserViewSet` class:
from myapp.permissions import IsAdminOrIsSelf from myapp.permissions import IsAdminOrIsSelf
from rest_framework.decorators import action from rest_framework.decorators import detail_route
@action(permission_classes=[IsAdminOrIsSelf]) class UserViewSet(ModelViewSet):
def set_password(self, request, pk=None):
... ...
@detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None):
...
The following URL pattern would additionally be generated: The following URL pattern would additionally be generated:
* URL pattern: `^users/{pk}/set_password/$` Name: `'user-set-password'` * URL pattern: `^users/{pk}/set_password/$` Name: `'user-set-password'`
For more information see the viewset documentation on [marking extra actions for routing][route-decorators].
# API Guide # API Guide
## SimpleRouter ## SimpleRouter
This router includes routes for the standard set of `list`, `create`, `retrieve`, `update`, `partial_update` and `destroy` actions. The viewset can also mark additional methods to be routed, using the `@link` or `@action` decorators. This router includes routes for the standard set of `list`, `create`, `retrieve`, `update`, `partial_update` and `destroy` actions. The viewset can also mark additional methods to be routed, using the `@detail_route` or `@list_route` decorators.
<table border=1> <table border=1>
<tr><th>URL Style</th><th>HTTP Method</th><th>Action</th><th>URL Name</th></tr> <tr><th>URL Style</th><th>HTTP Method</th><th>Action</th><th>URL Name</th></tr>
<tr><td rowspan=2>{prefix}/</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr> <tr><td rowspan=2>{prefix}/</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr>
<tr><td>POST</td><td>create</td></tr> <tr><td>POST</td><td>create</td></tr>
<tr><td>{prefix}/{methodname}/</td><td>GET, or as specified by `methods` argument</td><td>`@list_route` decorated method</td><td>{basename}-{methodname}</td></tr>
<tr><td rowspan=4>{prefix}/{lookup}/</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr> <tr><td rowspan=4>{prefix}/{lookup}/</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr>
<tr><td>PUT</td><td>update</td></tr> <tr><td>PUT</td><td>update</td></tr>
<tr><td>PATCH</td><td>partial_update</td></tr> <tr><td>PATCH</td><td>partial_update</td></tr>
<tr><td>DELETE</td><td>destroy</td></tr> <tr><td>DELETE</td><td>destroy</td></tr>
<tr><td rowspan=2>{prefix}/{lookup}/{methodname}/</td><td>GET</td><td>@link decorated method</td><td rowspan=2>{basename}-{methodname}</td></tr> <tr><td>{prefix}/{lookup}/{methodname}/</td><td>GET, or as specified by `methods` argument</td><td>`@detail_route` decorated method</td><td>{basename}-{methodname}</td></tr>
<tr><td>POST</td><td>@action decorated method</td></tr>
</table> </table>
By default the URLs created by `SimpleRouter` are appended with a trailing slash. By default the URLs created by `SimpleRouter` are appended with a trailing slash.
...@@ -90,6 +95,12 @@ This behavior can be modified by setting the `trailing_slash` argument to `False ...@@ -90,6 +95,12 @@ This behavior can be modified by setting the `trailing_slash` argument to `False
Trailing slashes are conventional in Django, but are not used by default in some other frameworks such as Rails. Which style you choose to use is largely a matter of preference, although some javascript frameworks may expect a particular routing style. Trailing slashes are conventional in Django, but are not used by default in some other frameworks such as Rails. Which style you choose to use is largely a matter of preference, although some javascript frameworks may expect a particular routing style.
The router will match lookup values containing any characters except slashes and period characters. For a more restrictive (or lenient) lookup pattern, set the `lookup_value_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
class MyModelViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
lookup_field = 'my_model_id'
lookup_value_regex = '[0-9a-f]{32}'
## DefaultRouter ## DefaultRouter
This router is similar to `SimpleRouter` as above, but additionally includes a default API root view, that returns a response containing hyperlinks to all the list views. It also generates routes for optional `.json` style format suffixes. This router is similar to `SimpleRouter` as above, but additionally includes a default API root view, that returns a response containing hyperlinks to all the list views. It also generates routes for optional `.json` style format suffixes.
...@@ -99,12 +110,12 @@ This router is similar to `SimpleRouter` as above, but additionally includes a d ...@@ -99,12 +110,12 @@ This router is similar to `SimpleRouter` as above, but additionally includes a d
<tr><td>[.format]</td><td>GET</td><td>automatically generated root view</td><td>api-root</td></tr></tr> <tr><td>[.format]</td><td>GET</td><td>automatically generated root view</td><td>api-root</td></tr></tr>
<tr><td rowspan=2>{prefix}/[.format]</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr> <tr><td rowspan=2>{prefix}/[.format]</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr>
<tr><td>POST</td><td>create</td></tr> <tr><td>POST</td><td>create</td></tr>
<tr><td>{prefix}/{methodname}/[.format]</td><td>GET, or as specified by `methods` argument</td><td>`@list_route` decorated method</td><td>{basename}-{methodname}</td></tr>
<tr><td rowspan=4>{prefix}/{lookup}/[.format]</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr> <tr><td rowspan=4>{prefix}/{lookup}/[.format]</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr>
<tr><td>PUT</td><td>update</td></tr> <tr><td>PUT</td><td>update</td></tr>
<tr><td>PATCH</td><td>partial_update</td></tr> <tr><td>PATCH</td><td>partial_update</td></tr>
<tr><td>DELETE</td><td>destroy</td></tr> <tr><td>DELETE</td><td>destroy</td></tr>
<tr><td rowspan=2>{prefix}/{lookup}/{methodname}/[.format]</td><td>GET</td><td>@link decorated method</td><td rowspan=2>{basename}-{methodname}</td></tr> <tr><td>{prefix}/{lookup}/{methodname}/[.format]</td><td>GET, or as specified by `methods` argument</td><td>`@detail_route` decorated method</td><td>{basename}-{methodname}</td></tr>
<tr><td>POST</td><td>@action decorated method</td></tr>
</table> </table>
As with `SimpleRouter` the trailing slashes on the URL routes can be removed by setting the `trailing_slash` argument to `False` when instantiating the router. As with `SimpleRouter` the trailing slashes on the URL routes can be removed by setting the `trailing_slash` argument to `False` when instantiating the router.
...@@ -133,28 +144,87 @@ The arguments to the `Route` named tuple are: ...@@ -133,28 +144,87 @@ The arguments to the `Route` named tuple are:
**initkwargs**: A dictionary of any additional arguments that should be passed when instantiating the view. Note that the `suffix` argument is reserved for identifying the viewset type, used when generating the view name and breadcrumb links. **initkwargs**: A dictionary of any additional arguments that should be passed when instantiating the view. Note that the `suffix` argument is reserved for identifying the viewset type, used when generating the view name and breadcrumb links.
## Customizing dynamic routes
You can also customize how the `@list_route` and `@detail_route` decorators are routed.
To route either or both of these decorators, include a `DynamicListRoute` and/or `DynamicDetailRoute` named tuple in the `.routes` list.
The arguments to `DynamicListRoute` and `DynamicDetailRoute` are:
**url**: A string representing the URL to be routed. May include the same format strings as `Route`, and additionally accepts the `{methodname}` and `{methodnamehyphen}` format strings.
**name**: The name of the URL as used in `reverse` calls. May include the following format strings: `{basename}`, `{methodname}` and `{methodnamehyphen}`.
**initkwargs**: A dictionary of any additional arguments that should be passed when instantiating the view.
## Example ## Example
The following example will only route to the `list` and `retrieve` actions, and does not use the trailing slash convention. The following example will only route to the `list` and `retrieve` actions, and does not use the trailing slash convention.
from rest_framework.routers import Route, SimpleRouter from rest_framework.routers import Route, DynamicDetailRoute, SimpleRouter
class ReadOnlyRouter(SimpleRouter): class CustomReadOnlyRouter(SimpleRouter):
""" """
A router for read-only APIs, which doesn't use trailing slashes. A router for read-only APIs, which doesn't use trailing slashes.
""" """
routes = [ routes = [
Route(url=r'^{prefix}$', Route(
mapping={'get': 'list'}, url=r'^{prefix}$',
name='{basename}-list', mapping={'get': 'list'},
initkwargs={'suffix': 'List'}), name='{basename}-list',
Route(url=r'^{prefix}/{lookup}$', initkwargs={'suffix': 'List'}
mapping={'get': 'retrieve'}, ),
name='{basename}-detail', Route(
initkwargs={'suffix': 'Detail'}) url=r'^{prefix}/{lookup}$',
mapping={'get': 'retrieve'},
name='{basename}-detail',
initkwargs={'suffix': 'Detail'}
),
DynamicDetailRoute(
url=r'^{prefix}/{lookup}/{methodnamehyphen}$',
name='{basename}-{methodnamehyphen}',
initkwargs={}
)
] ]
The `SimpleRouter` class provides another example of setting the `.routes` attribute. Let's take a look at the routes our `CustomReadOnlyRouter` would generate for a simple viewset.
`views.py`:
class UserViewSet(viewsets.ReadOnlyModelViewSet):
"""
A viewset that provides the standard actions
"""
queryset = User.objects.all()
serializer_class = UserSerializer
lookup_field = 'username'
@detail_route()
def group_names(self, request):
"""
Returns a list of all the group names that the given
user belongs to.
"""
user = self.get_object()
groups = user.groups.all()
return Response([group.name for group in groups])
`urls.py`:
router = CustomReadOnlyRouter()
router.register('users', UserViewSet)
urlpatterns = router.urls
The following mappings would be generated...
<table border=1>
<tr><th>URL</th><th>HTTP Method</th><th>Action</th><th>URL Name</th></tr>
<tr><td>/users</td><td>GET</td><td>list</td><td>user-list</td></tr>
<tr><td>/users/{username}</td><td>GET</td><td>retrieve</td><td>user-detail</td></tr>
<tr><td>/users/{username}/group-names</td><td>GET</td><td>group_names</td><td>user-group-names</td></tr>
</table>
For another example of setting the `.routes` attribute, see the source code for the `SimpleRouter` class.
## Advanced custom routers ## Advanced custom routers
...@@ -180,6 +250,7 @@ The [wq.db package][wq.db] provides an advanced [Router][wq.db-router] class (an ...@@ -180,6 +250,7 @@ The [wq.db package][wq.db] provides an advanced [Router][wq.db-router] class (an
app.router.register_model(MyModel) app.router.register_model(MyModel)
[cite]: http://guides.rubyonrails.org/routing.html [cite]: http://guides.rubyonrails.org/routing.html
[route-decorators]: viewsets.html#marking-extra-actions-for-routing
[drf-nested-routers]: https://github.com/alanjds/drf-nested-routers [drf-nested-routers]: https://github.com/alanjds/drf-nested-routers
[wq.db]: http://wq.io/wq.db [wq.db]: http://wq.io/wq.db
[wq.db-router]: http://wq.io/docs/app.py [wq.db-router]: http://wq.io/docs/app.py
...@@ -73,8 +73,8 @@ Sometimes when serializing objects, you may not want to represent everything exa ...@@ -73,8 +73,8 @@ Sometimes when serializing objects, you may not want to represent everything exa
If you need to customize the serialized value of a particular field, you can do this by creating a `transform_<fieldname>` method. For example if you needed to render some markdown from a text field: If you need to customize the serialized value of a particular field, you can do this by creating a `transform_<fieldname>` method. For example if you needed to render some markdown from a text field:
description = serializers.TextField() description = serializers.CharField()
description_html = serializers.TextField(source='description', read_only=True) description_html = serializers.CharField(source='description', read_only=True)
def transform_description_html(self, obj, value): def transform_description_html(self, obj, value):
from django.contrib.markup.templatetags.markup import markdown from django.contrib.markup.templatetags.markup import markdown
...@@ -464,7 +464,7 @@ For more specific requirements such as specifying a different lookup for each fi ...@@ -464,7 +464,7 @@ For more specific requirements such as specifying a different lookup for each fi
model = Account model = Account
fields = ('url', 'account_name', 'users', 'created') fields = ('url', 'account_name', 'users', 'created')
## Overiding the URL field behavior ## Overriding the URL field behavior
The name of the URL field defaults to 'url'. You can override this globally, by using the `URL_FIELD_NAME` setting. The name of the URL field defaults to 'url'. You can override this globally, by using the `URL_FIELD_NAME` setting.
...@@ -478,7 +478,7 @@ You can also override this on a per-serializer basis by using the `url_field_nam ...@@ -478,7 +478,7 @@ You can also override this on a per-serializer basis by using the `url_field_nam
**Note**: The generic view implementations normally generate a `Location` header in response to successful `POST` requests. Serializers using `url_field_name` option will not have this header automatically included by the view. If you need to do so you will ned to also override the view's `get_success_headers()` method. **Note**: The generic view implementations normally generate a `Location` header in response to successful `POST` requests. Serializers using `url_field_name` option will not have this header automatically included by the view. If you need to do so you will ned to also override the view's `get_success_headers()` method.
You can also overide the URL field's view name and lookup field without overriding the field explicitly, by using the `view_name` and `lookup_field` options, like so: You can also override the URL field's view name and lookup field without overriding the field explicitly, by using the `view_name` and `lookup_field` options, like so:
class AccountSerializer(serializers.HyperlinkedModelSerializer): class AccountSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
......
...@@ -377,5 +377,11 @@ The name of a parameter in the URL conf that may be used to provide a format suf ...@@ -377,5 +377,11 @@ The name of a parameter in the URL conf that may be used to provide a format suf
Default: `'format'` Default: `'format'`
#### NUM_PROXIES
An integer of 0 or more, that may be used to specify the number of application proxies that the API runs behind. This allows throttling to more accurately identify client IP addresses. If set to `None` then less strict IP matching will be used by the throttle classes.
Default: `None`
[cite]: http://www.python.org/dev/peps/pep-0020/ [cite]: http://www.python.org/dev/peps/pep-0020/
[strftime]: http://docs.python.org/2/library/time.html#time.strftime [strftime]: http://docs.python.org/2/library/time.html#time.strftime
...@@ -35,7 +35,7 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C ...@@ -35,7 +35,7 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C
'DEFAULT_THROTTLE_RATES': { 'DEFAULT_THROTTLE_RATES': {
'anon': '100/day', 'anon': '100/day',
'user': '1000/day' 'user': '1000/day'
} }
} }
The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period. The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period.
...@@ -66,6 +66,16 @@ Or, if you're using the `@api_view` decorator with function based views. ...@@ -66,6 +66,16 @@ Or, if you're using the `@api_view` decorator with function based views.
} }
return Response(content) return Response(content)
## How clients are identified
The `X-Forwarded-For` and `Remote-Addr` HTTP headers are used to uniquely identify client IP addresses for throttling. If the `X-Forwarded-For` header is present then it will be used, otherwise the value of the `Remote-Addr` header will be used.
If you need to strictly identify unique client IP addresses, you'll need to first configure the number of application proxies that the API runs behind by setting the `NUM_PROXIES` setting. This setting should be an integer of zero or more. If set to non-zero then the client IP will be identified as being the last IP address in the `X-Forwarded-For` header, once any application proxy IP addresses have first been excluded. If set to zero, then the `Remote-Addr` header will always be used as the identifying IP address.
It is important to understand that if you configure the `NUM_PROXIES` setting, then all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
Further context on how the `X-Forwarded-For` header works, and identifing a remote client IP can be [found here][identifing-clients].
## Setting up the cache ## Setting up the cache
The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details. The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details.
...@@ -178,5 +188,6 @@ The following is an example of a rate throttle, that will randomly throttle 1 in ...@@ -178,5 +188,6 @@ The following is an example of a rate throttle, that will randomly throttle 1 in
[cite]: https://dev.twitter.com/docs/error-codes-responses [cite]: https://dev.twitter.com/docs/error-codes-responses
[permissions]: permissions.md [permissions]: permissions.md
[identifing-clients]: http://oxpedia.org/wiki/index.php?title=AppSuite:Grizzly#Multiple_Proxies_in_front_of_the_cluster
[cache-setting]: https://docs.djangoproject.com/en/dev/ref/settings/#caches [cache-setting]: https://docs.djangoproject.com/en/dev/ref/settings/#caches
[cache-docs]: https://docs.djangoproject.com/en/dev/topics/cache/#setting-up-the-cache [cache-docs]: https://docs.djangoproject.com/en/dev/topics/cache/#setting-up-the-cache
...@@ -70,7 +70,7 @@ There are two main advantages of using a `ViewSet` class over using a `View` cla ...@@ -70,7 +70,7 @@ There are two main advantages of using a `ViewSet` class over using a `View` cla
Both of these come with a trade-off. Using regular views and URL confs is more explicit and gives you more control. ViewSets are helpful if you want to get up and running quickly, or when you have a large API and you want to enforce a consistent URL configuration throughout. Both of these come with a trade-off. Using regular views and URL confs is more explicit and gives you more control. ViewSets are helpful if you want to get up and running quickly, or when you have a large API and you want to enforce a consistent URL configuration throughout.
## Marking extra methods for routing ## Marking extra actions for routing
The default routers included with REST framework will provide routes for a standard set of create/retrieve/update/destroy style operations, as shown below: The default routers included with REST framework will provide routes for a standard set of create/retrieve/update/destroy style operations, as shown below:
...@@ -101,14 +101,16 @@ The default routers included with REST framework will provide routes for a stand ...@@ -101,14 +101,16 @@ The default routers included with REST framework will provide routes for a stand
def destroy(self, request, pk=None): def destroy(self, request, pk=None):
pass pass
If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link` or `@action` decorators. The `@link` decorator will route `GET` requests, and the `@action` decorator will route `POST` requests. If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@detail_route` or `@list_route` decorators.
The `@detail_route` decorator contains `pk` in its URL pattern and is intended for methods which require a single instance. The `@list_route` decorator is intended for methods which operate on a list of objects.
For example: For example:
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import viewsets
from rest_framework import status from rest_framework import status
from rest_framework.decorators import action from rest_framework import viewsets
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from myapp.serializers import UserSerializer, PasswordSerializer from myapp.serializers import UserSerializer, PasswordSerializer
...@@ -119,7 +121,7 @@ For example: ...@@ -119,7 +121,7 @@ For example:
queryset = User.objects.all() queryset = User.objects.all()
serializer_class = UserSerializer serializer_class = UserSerializer
@action() @detail_route(methods=['post'])
def set_password(self, request, pk=None): def set_password(self, request, pk=None):
user = self.get_object() user = self.get_object()
serializer = PasswordSerializer(data=request.DATA) serializer = PasswordSerializer(data=request.DATA)
...@@ -131,21 +133,27 @@ For example: ...@@ -131,21 +133,27 @@ For example:
return Response(serializer.errors, return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST) status=status.HTTP_400_BAD_REQUEST)
The `@action` and `@link` decorators can additionally take extra arguments that will be set for the routed view only. For example... @list_route()
def recent_users(self, request):
recent_users = User.objects.all().order('-last_login')
page = self.paginate_queryset(recent_users)
serializer = self.get_pagination_serializer(page)
return Response(serializer.data)
The decorators can additionally take extra arguments that will be set for the routed view only. For example...
@action(permission_classes=[IsAdminOrIsSelf]) @detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None): def set_password(self, request, pk=None):
... ...
The `@action` decorator will route `POST` requests by default, but may also accept other HTTP methods, by using the `method` argument. For example: The `@action` decorator will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
@action(methods=['POST', 'DELETE']) @detail_route(methods=['post', 'delete'])
def unset_password(self, request, pk=None): def unset_password(self, request, pk=None):
... ...
The two new actions will then be available at the urls `^users/{pk}/set_password/$` and `^users/{pk}/unset_password/$` The two new actions will then be available at the urls `^users/{pk}/set_password/$` and `^users/{pk}/unset_password/$`
--- ---
# API Reference # API Reference
......
...@@ -206,19 +206,9 @@ General guides to using REST framework. ...@@ -206,19 +206,9 @@ General guides to using REST framework.
## Development ## Development
If you want to work on REST framework itself, clone the repository, then... See the [Contribution guidelines][contributing] for information on how to clone
the repository, run the test suite and contribute changes back to REST
Build the docs: Framework.
./mkdocs.py
Run the tests:
./rest_framework/runtests/runtests.py
To run the tests against all supported configurations, first install [the tox testing tool][tox] globally, using `pip install tox`, then simply run `tox`:
tox
## Support ## Support
......
* Writable nested serializers.
* List/detail routes.
* 1.3 Support dropped, install six for <=1.4.?.
* `allow_none` for char fields
* `trailing_slash = True` --> `[^/]`, `trailing_slash = False` --> `[^/.]`, becomes simply `[^/]` and `lookup_value_regex` is added.
...@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then: ...@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then:
pip install -r optionals.txt pip install -r optionals.txt
# Run the tests # Run the tests
rest_framework/runtests/runtests.py py.test
You can also use the excellent [tox][tox] testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run: You can also use the excellent [tox][tox] testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run:
......
...@@ -38,11 +38,8 @@ You can determine your currently installed version using `pip freeze`: ...@@ -38,11 +38,8 @@ You can determine your currently installed version using `pip freeze`:
--- ---
## 2.3.x series ### 2.4.0
### 2.3.x
**Date**: April 2014
* Added compatibility with Django 1.7's native migrations. * Added compatibility with Django 1.7's native migrations.
**IMPORTANT**: In order to continue to use south with Django <1.7 you **must** provide **IMPORTANT**: In order to continue to use south with Django <1.7 you **must** provide
...@@ -51,21 +48,41 @@ You can determine your currently installed version using `pip freeze`: ...@@ -51,21 +48,41 @@ You can determine your currently installed version using `pip freeze`:
SOUTH_MIGRATION_MODULES = { SOUTH_MIGRATION_MODULES = {
'authtoken': 'rest_framework.authtoken.south_migrations', 'authtoken': 'rest_framework.authtoken.south_migrations',
} }
* Use py.test
* `@detail_route` and `@list_route` decorators replace `@action` and `@link`.
* `six` no longer bundled. For Django <= 1.4.1, install `six` package.
* Support customizable view name and description functions, using the `VIEW_NAME_FUNCTION` and `VIEW_DESCRIPTION_FUNCTION` settings.
* Added `NUM_PROXIES` setting for smarter client IP identification.
* Added `MAX_PAGINATE_BY` setting and `max_paginate_by` generic view attribute.
* Added `cache` attribute to throttles to allow overriding of default cache.
* Bugfix: `?page_size=0` query parameter now falls back to default page size for view, instead of always turning pagination off.
## 2.3.x series
* Fix nested serializers linked through a backward foreign key relation ### 2.3.14
* Fix bad links for the `BrowsableAPIRenderer` with `YAMLRenderer`
* Add `UnicodeYAMLRenderer` that extends `YAMLRenderer` with unicode **Date**: 12th June 2014
* Fix `parse_header` argument convertion
* Fix mediatype detection under Python3 * **Security fix**: Escape request path when it is include as part of the login and logout links in the browsable API.
* Web browseable API now offers blank option on dropdown when the field is not required * `help_text` and `verbose_name` automatically set for related fields on `ModelSerializer`.
* `APIException` representation improved for logging purposes * Fix nested serializers linked through a backward foreign key relation.
* Allow source="*" within nested serializers * Fix bad links for the `BrowsableAPIRenderer` with `YAMLRenderer`.
* Better support for custom oauth2 provider backends * Add `UnicodeYAMLRenderer` that extends `YAMLRenderer` with unicode.
* Fix field validation if it's optional and has no value * Fix `parse_header` argument convertion.
* Add `SEARCH_PARAM` and `ORDERING_PARAM` * Fix mediatype detection under Python 3.
* Fix `APIRequestFactory` to support arguments within the url string for GET * Web browseable API now offers blank option on dropdown when the field is not required.
* Allow three transport modes for access tokens when accessing a protected resource * `APIException` representation improved for logging purposes.
* Fix `Request`'s `QueryDict` encoding * Allow source="*" within nested serializers.
* Better support for custom oauth2 provider backends.
* Fix field validation if it's optional and has no value.
* Add `SEARCH_PARAM` and `ORDERING_PARAM`.
* Fix `APIRequestFactory` to support arguments within the url string for GET.
* Allow three transport modes for access tokens when accessing a protected resource.
* Fix `QueryDict` encoding on request objects.
* Ensure throttle keys do not contain spaces, as those are invalid if using `memcached`.
* Support `blank_display_value` on `ChoiceField`.
### 2.3.13 ### 2.3.13
......
...@@ -25,7 +25,7 @@ Here we've used the `ReadOnlyModelViewSet` class to automatically provide the de ...@@ -25,7 +25,7 @@ Here we've used the `ReadOnlyModelViewSet` class to automatically provide the de
Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class.
from rest_framework.decorators import link from rest_framework.decorators import detail_route
class SnippetViewSet(viewsets.ModelViewSet): class SnippetViewSet(viewsets.ModelViewSet):
""" """
...@@ -39,7 +39,7 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl ...@@ -39,7 +39,7 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl
permission_classes = (permissions.IsAuthenticatedOrReadOnly, permission_classes = (permissions.IsAuthenticatedOrReadOnly,
IsOwnerOrReadOnly,) IsOwnerOrReadOnly,)
@link(renderer_classes=[renderers.StaticHTMLRenderer]) @detail_route(renderer_classes=[renderers.StaticHTMLRenderer])
def highlight(self, request, *args, **kwargs): def highlight(self, request, *args, **kwargs):
snippet = self.get_object() snippet = self.get_object()
return Response(snippet.highlighted) return Response(snippet.highlighted)
...@@ -49,9 +49,9 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl ...@@ -49,9 +49,9 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl
This time we've used the `ModelViewSet` class in order to get the complete set of default read and write operations. This time we've used the `ModelViewSet` class in order to get the complete set of default read and write operations.
Notice that we've also used the `@link` decorator to create a custom action, named `highlight`. This decorator can be used to add any custom endpoints that don't fit into the standard `create`/`update`/`delete` style. Notice that we've also used the `@detail_route` decorator to create a custom action, named `highlight`. This decorator can be used to add any custom endpoints that don't fit into the standard `create`/`update`/`delete` style.
Custom actions which use the `@link` decorator will respond to `GET` requests. We could have instead used the `@action` decorator if we wanted an action that responded to `POST` requests. Custom actions which use the `@detail_route` decorator will respond to `GET` requests. We can use the `methods` argument if we wanted an action that responded to `POST` requests.
## Binding ViewSets to URLs explicitly ## Binding ViewSets to URLs explicitly
......
[pytest]
addopts = --tb=short
...@@ -8,7 +8,7 @@ ______ _____ _____ _____ __ _ ...@@ -8,7 +8,7 @@ ______ _____ _____ _____ __ _
""" """
__title__ = 'Django REST framework' __title__ = 'Django REST framework'
__version__ = '2.3.13' __version__ = '2.3.14'
__author__ = 'Tom Christie' __author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause' __license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2014 Tom Christie' __copyright__ = 'Copyright 2011-2014 Tom Christie'
......
...@@ -6,9 +6,9 @@ import base64 ...@@ -6,9 +6,9 @@ import base64
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.middleware.csrf import CsrfViewMiddleware
from django.conf import settings from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
from rest_framework.compat import oauth2_provider, provider_now, check_nonce from rest_framework.compat import oauth2_provider, provider_now, check_nonce
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
......
...@@ -3,13 +3,14 @@ The most important decorator in this module is `@api_view`, which is used ...@@ -3,13 +3,14 @@ The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework. for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function There are also various decorators for setting the API policies on function
based views, as well as the `@action` and `@link` decorators, which are based views, as well as the `@detail_route` and `@list_route` decorators, which are
used to annotate methods on viewsets that should be included by routers. used to annotate methods on viewsets that should be included by routers.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from rest_framework.compat import six from rest_framework.compat import six
from rest_framework.views import APIView from rest_framework.views import APIView
import types import types
import warnings
def api_view(http_method_names): def api_view(http_method_names):
...@@ -107,12 +108,40 @@ def permission_classes(permission_classes): ...@@ -107,12 +108,40 @@ def permission_classes(permission_classes):
return decorator return decorator
def detail_route(methods=['get'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail requests.
"""
def decorator(func):
func.bind_to_methods = methods
func.detail = True
func.kwargs = kwargs
return func
return decorator
def list_route(methods=['get'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for list requests.
"""
def decorator(func):
func.bind_to_methods = methods
func.detail = False
func.kwargs = kwargs
return func
return decorator
# These are now pending deprecation, in favor of `detail_route` and `list_route`.
def link(**kwargs): def link(**kwargs):
""" """
Used to mark a method on a ViewSet that should be routed for GET requests. Used to mark a method on a ViewSet that should be routed for detail GET requests.
""" """
msg = 'link is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = ['get'] func.bind_to_methods = ['get']
func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator
...@@ -120,10 +149,13 @@ def link(**kwargs): ...@@ -120,10 +149,13 @@ def link(**kwargs):
def action(methods=['post'], **kwargs): def action(methods=['post'], **kwargs):
""" """
Used to mark a method on a ViewSet that should be routed for POST requests. Used to mark a method on a ViewSet that should be routed for detail POST requests.
""" """
msg = 'action is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = methods func.bind_to_methods = methods
func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator
\ No newline at end of file
...@@ -18,12 +18,14 @@ from django.conf import settings ...@@ -18,12 +18,14 @@ from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import QueryDict from django.http import QueryDict
from django.forms import widgets from django.forms import widgets
from django.utils import timezone
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text, BytesIO, six, smart_text,
force_text, is_non_str_iterable force_text, is_non_str_iterable
) )
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
...@@ -154,7 +156,12 @@ class Field(object): ...@@ -154,7 +156,12 @@ class Field(object):
def widget_html(self): def widget_html(self):
if not self.widget: if not self.widget:
return '' return ''
return self.widget.render(self._name, self._value)
attrs = {}
if 'id' not in self.widget.attrs:
attrs['id'] = self._name
return self.widget.render(self._name, self._value, attrs=attrs)
def label_tag(self): def label_tag(self):
return '<label for="%s">%s:</label>' % (self._name, self.label) return '<label for="%s">%s:</label>' % (self._name, self.label)
...@@ -260,13 +267,6 @@ class WritableField(Field): ...@@ -260,13 +267,6 @@ class WritableField(Field):
validators=[], error_messages=None, widget=None, validators=[], error_messages=None, widget=None,
default=None, blank=None): default=None, blank=None):
# 'blank' is to be deprecated in favor of 'required'
if blank is not None:
warnings.warn('The `blank` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
DeprecationWarning, stacklevel=2)
required = not(blank)
super(WritableField, self).__init__(source=source, label=label, help_text=help_text) super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only self.read_only = read_only
...@@ -460,8 +460,9 @@ class CharField(WritableField): ...@@ -460,8 +460,9 @@ class CharField(WritableField):
type_label = 'string' type_label = 'string'
form_field_class = forms.CharField form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, *args, **kwargs): def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length self.max_length, self.min_length = max_length, min_length
self.allow_none = allow_none
super(CharField, self).__init__(*args, **kwargs) super(CharField, self).__init__(*args, **kwargs)
if min_length is not None: if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length)) self.validators.append(validators.MinLengthValidator(min_length))
...@@ -469,7 +470,9 @@ class CharField(WritableField): ...@@ -469,7 +470,9 @@ class CharField(WritableField):
self.validators.append(validators.MaxLengthValidator(max_length)) self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value): def from_native(self, value):
if isinstance(value, six.string_types) or value is None: if value is None and not self.allow_none:
return ''
if isinstance(value, six.string_types):
return value return value
return smart_text(value) return smart_text(value)
...@@ -501,7 +504,7 @@ class SlugField(CharField): ...@@ -501,7 +504,7 @@ class SlugField(CharField):
class ChoiceField(WritableField): class ChoiceField(WritableField):
type_name = 'ChoiceField' type_name = 'ChoiceField'
type_label = 'multiple choice' type_label = 'choice'
form_field_class = forms.ChoiceField form_field_class = forms.ChoiceField
widget = widgets.Select widget = widgets.Select
default_error_messages = { default_error_messages = {
...@@ -509,12 +512,16 @@ class ChoiceField(WritableField): ...@@ -509,12 +512,16 @@ class ChoiceField(WritableField):
'the available choices.'), 'the available choices.'),
} }
def __init__(self, choices=(), *args, **kwargs): def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '') self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs) super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices self.choices = choices
if not self.required: if not self.required:
self.choices = BLANK_CHOICE_DASH + self.choices if blank_display_value is None:
blank_choice = BLANK_CHOICE_DASH
else:
blank_choice = [('', blank_display_value)]
self.choices = blank_choice + self.choices
def _get_choices(self): def _get_choices(self):
return self._choices return self._choices
...@@ -1018,9 +1025,9 @@ class SerializerMethodField(Field): ...@@ -1018,9 +1025,9 @@ class SerializerMethodField(Field):
A field that gets its value by calling a method on the serializer it's attached to. A field that gets its value by calling a method on the serializer it's attached to.
""" """
def __init__(self, method_name): def __init__(self, method_name, *args, **kwargs):
self.method_name = method_name self.method_name = method_name
super(SerializerMethodField, self).__init__() super(SerializerMethodField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
value = getattr(self.parent, self.method_name)(obj) value = getattr(self.parent, self.method_name)(obj)
......
...@@ -90,8 +90,8 @@ class GenericAPIView(views.APIView): ...@@ -90,8 +90,8 @@ class GenericAPIView(views.APIView):
'view': self 'view': self
} }
def get_serializer(self, instance=None, data=None, def get_serializer(self, instance=None, data=None, files=None, many=False,
files=None, many=False, partial=False): partial=False, allow_add_remove=False):
""" """
Return the serializer instance that should be used for validating and Return the serializer instance that should be used for validating and
deserializing input, and for serializing output. deserializing input, and for serializing output.
...@@ -99,7 +99,9 @@ class GenericAPIView(views.APIView): ...@@ -99,7 +99,9 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files, return serializer_class(instance, data=data, files=files,
many=many, partial=partial, context=context) many=many, partial=partial,
allow_add_remove=allow_add_remove,
context=context)
def get_pagination_serializer(self, page): def get_pagination_serializer(self, page):
""" """
...@@ -121,11 +123,11 @@ class GenericAPIView(views.APIView): ...@@ -121,11 +123,11 @@ class GenericAPIView(views.APIView):
deprecated_style = False deprecated_style = False
if page_size is not None: if page_size is not None:
warnings.warn('The `page_size` parameter to `paginate_queryset()` ' warnings.warn('The `page_size` parameter to `paginate_queryset()` '
'is due to be deprecated. ' 'is deprecated. '
'Note that the return style of this method is also ' 'Note that the return style of this method is also '
'changed, and will simply return a page object ' 'changed, and will simply return a page object '
'when called without a `page_size` argument.', 'when called without a `page_size` argument.',
PendingDeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
deprecated_style = True deprecated_style = True
else: else:
# Determine the required page size. # Determine the required page size.
...@@ -136,10 +138,10 @@ class GenericAPIView(views.APIView): ...@@ -136,10 +138,10 @@ class GenericAPIView(views.APIView):
if not self.allow_empty: if not self.allow_empty:
warnings.warn( warnings.warn(
'The `allow_empty` parameter is due to be deprecated. ' 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override ' 'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.', '`get_queryset()` and explicitly raise a 404 on empty querysets.',
PendingDeprecationWarning, stacklevel=2 DeprecationWarning, stacklevel=2
) )
paginator = self.paginator_class(queryset, page_size, paginator = self.paginator_class(queryset, page_size,
...@@ -187,10 +189,10 @@ class GenericAPIView(views.APIView): ...@@ -187,10 +189,10 @@ class GenericAPIView(views.APIView):
if not filter_backends and self.filter_backend: if not filter_backends and self.filter_backend:
warnings.warn( warnings.warn(
'The `filter_backend` attribute and `FILTER_BACKEND` setting ' 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
'are due to be deprecated in favor of a `filter_backends` ' 'are deprecated in favor of a `filter_backends` '
'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
'a *list* of filter backend classes.', 'a *list* of filter backend classes.',
PendingDeprecationWarning, stacklevel=2 DeprecationWarning, stacklevel=2
) )
filter_backends = [self.filter_backend] filter_backends = [self.filter_backend]
return filter_backends return filter_backends
...@@ -211,8 +213,8 @@ class GenericAPIView(views.APIView): ...@@ -211,8 +213,8 @@ class GenericAPIView(views.APIView):
""" """
if queryset is not None: if queryset is not None:
warnings.warn('The `queryset` parameter to `get_paginate_by()` ' warnings.warn('The `queryset` parameter to `get_paginate_by()` '
'is due to be deprecated.', 'is deprecated.',
PendingDeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
if self.paginate_by_param: if self.paginate_by_param:
try: try:
...@@ -295,16 +297,16 @@ class GenericAPIView(views.APIView): ...@@ -295,16 +297,16 @@ class GenericAPIView(views.APIView):
filter_kwargs = {self.lookup_field: lookup} filter_kwargs = {self.lookup_field: lookup}
elif pk is not None and self.lookup_field == 'pk': elif pk is not None and self.lookup_field == 'pk':
warnings.warn( warnings.warn(
'The `pk_url_kwarg` attribute is due to be deprecated. ' 'The `pk_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead', 'Use the `lookup_field` attribute instead',
PendingDeprecationWarning DeprecationWarning
) )
filter_kwargs = {'pk': pk} filter_kwargs = {'pk': pk}
elif slug is not None and self.lookup_field == 'pk': elif slug is not None and self.lookup_field == 'pk':
warnings.warn( warnings.warn(
'The `slug_url_kwarg` attribute is due to be deprecated. ' 'The `slug_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead', 'Use the `lookup_field` attribute instead',
PendingDeprecationWarning DeprecationWarning
) )
filter_kwargs = {self.slug_field: slug} filter_kwargs = {self.slug_field: slug}
else: else:
...@@ -524,9 +526,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, ...@@ -524,9 +526,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
class MultipleObjectAPIView(GenericAPIView): class MultipleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn( warnings.warn(
'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' 'Subclassing `MultipleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.', 'You should simply subclass `GenericAPIView` instead.',
PendingDeprecationWarning, stacklevel=2 DeprecationWarning, stacklevel=2
) )
super(MultipleObjectAPIView, self).__init__(*args, **kwargs) super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
...@@ -534,8 +536,8 @@ class MultipleObjectAPIView(GenericAPIView): ...@@ -534,8 +536,8 @@ class MultipleObjectAPIView(GenericAPIView):
class SingleObjectAPIView(GenericAPIView): class SingleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn( warnings.warn(
'Subclassing `SingleObjectAPIView` is due to be deprecated. ' 'Subclassing `SingleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.', 'You should simply subclass `GenericAPIView` instead.',
PendingDeprecationWarning, stacklevel=2 DeprecationWarning, stacklevel=2
) )
super(SingleObjectAPIView, self).__init__(*args, **kwargs) super(SingleObjectAPIView, self).__init__(*args, **kwargs)
...@@ -26,14 +26,14 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) ...@@ -26,14 +26,14 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
include = [] include = []
if pk: if pk:
# Pending deprecation # Deprecated
pk_field = obj._meta.pk pk_field = obj._meta.pk
while pk_field.rel: while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name) include.append(pk_field.name)
if slug_field: if slug_field:
# Pending deprecation # Deprecated
include.append(slug_field) include.append(slug_field)
if lookup_field and lookup_field != 'pk': if lookup_field and lookup_field != 'pk':
...@@ -79,10 +79,10 @@ class ListModelMixin(object): ...@@ -79,10 +79,10 @@ class ListModelMixin(object):
# `.allow_empty = False`, to raise 404 errors on empty querysets. # `.allow_empty = False`, to raise 404 errors on empty querysets.
if not self.allow_empty and not self.object_list: if not self.allow_empty and not self.object_list:
warnings.warn( warnings.warn(
'The `allow_empty` parameter is due to be deprecated. ' 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override ' 'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.', '`get_queryset()` and explicitly raise a 404 on empty querysets.',
PendingDeprecationWarning DeprecationWarning
) )
class_name = self.__class__.__name__ class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name} error_msg = self.empty_error % {'class_name': class_name}
......
...@@ -2,15 +2,12 @@ ...@@ -2,15 +2,12 @@
Provides a set of pluggable permission policies. Provides a set of pluggable permission policies.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import inspect
import warnings
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
from django.http import Http404 from django.http import Http404
from rest_framework.compat import (get_model_name, oauth2_provider_scope, from rest_framework.compat import (get_model_name, oauth2_provider_scope,
oauth2_constants) oauth2_constants)
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
class BasePermission(object): class BasePermission(object):
""" """
...@@ -27,13 +24,6 @@ class BasePermission(object): ...@@ -27,13 +24,6 @@ class BasePermission(object):
""" """
Return `True` if permission is granted, `False` otherwise. Return `True` if permission is granted, `False` otherwise.
""" """
if len(inspect.getargspec(self.has_permission).args) == 4:
warnings.warn(
'The `obj` argument in `has_permission` is deprecated. '
'Use `has_object_permission()` instead for object permissions.',
DeprecationWarning, stacklevel=2
)
return self.has_permission(request, view, obj)
return True return True
......
...@@ -41,14 +41,6 @@ class RelatedField(WritableField): ...@@ -41,14 +41,6 @@ class RelatedField(WritableField):
many = False many = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# 'null' is to be deprecated in favor of 'required'
if 'null' in kwargs:
warnings.warn('The `null` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
queryset = kwargs.pop('queryset', None) queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many) self.many = kwargs.pop('many', self.many)
if self.many: if self.many:
...@@ -330,7 +322,7 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -330,7 +322,7 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'), 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
} }
# These are all pending deprecation # These are all deprecated
pk_url_kwarg = 'pk' pk_url_kwarg = 'pk'
slug_field = 'slug' slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
...@@ -344,16 +336,16 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -344,16 +336,16 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.format = kwargs.pop('format', None) self.format = kwargs.pop('format', None)
# These are pending deprecation # These are deprecated
if 'pk_url_kwarg' in kwargs: if 'pk_url_kwarg' in kwargs:
msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs: if 'slug_url_kwarg' in kwargs:
msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs: if 'slug_field' in kwargs:
msg = 'slug_field is pending deprecation. Use lookup_field instead.' msg = 'slug_field is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field) self.slug_field = kwargs.pop('slug_field', self.slug_field)
...@@ -396,9 +388,9 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -396,9 +388,9 @@ class HyperlinkedRelatedField(RelatedField):
# If the lookup succeeds using the default slug params, # If the lookup succeeds using the default slug params,
# then `slug_field` is being used implicitly, and we # then `slug_field` is being used implicitly, and we
# we need to warn about the pending deprecation. # we need to warn about the pending deprecation.
msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ msg = 'Implicit slug field hyperlinked fields are deprecated.' \
'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
return ret return ret
except NoReverseMatch: except NoReverseMatch:
pass pass
...@@ -432,14 +424,11 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -432,14 +424,11 @@ class HyperlinkedRelatedField(RelatedField):
request = self.context.get('request', None) request = self.context.get('request', None)
format = self.format or self.context.get('format', None) format = self.format or self.context.get('format', None)
if request is None: assert request is not None, (
msg = ( "`HyperlinkedRelatedField` requires the request in the serializer "
"Using `HyperlinkedRelatedField` without including the request " "context. Add `context={'request': request}` when instantiating "
"in the serializer context is deprecated. " "the serializer."
"Add `context={'request': request}` when instantiating " )
"the serializer."
)
warnings.warn(msg, DeprecationWarning, stacklevel=4)
# If the object has not yet been saved then we cannot hyperlink to it. # If the object has not yet been saved then we cannot hyperlink to it.
if getattr(obj, 'pk', None) is None: if getattr(obj, 'pk', None) is None:
...@@ -499,7 +488,7 @@ class HyperlinkedIdentityField(Field): ...@@ -499,7 +488,7 @@ class HyperlinkedIdentityField(Field):
lookup_field = 'pk' lookup_field = 'pk'
read_only = True read_only = True
# These are all pending deprecation # These are all deprecated
pk_url_kwarg = 'pk' pk_url_kwarg = 'pk'
slug_field = 'slug' slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
...@@ -515,16 +504,16 @@ class HyperlinkedIdentityField(Field): ...@@ -515,16 +504,16 @@ class HyperlinkedIdentityField(Field):
lookup_field = kwargs.pop('lookup_field', None) lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field self.lookup_field = lookup_field or self.lookup_field
# These are pending deprecation # These are deprecated
if 'pk_url_kwarg' in kwargs: if 'pk_url_kwarg' in kwargs:
msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs: if 'slug_url_kwarg' in kwargs:
msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs: if 'slug_field' in kwargs:
msg = 'slug_field is pending deprecation. Use lookup_field instead.' msg = 'slug_field is deprecated. Use lookup_field instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field) self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field default_slug_kwarg = self.slug_url_kwarg or self.slug_field
...@@ -538,11 +527,11 @@ class HyperlinkedIdentityField(Field): ...@@ -538,11 +527,11 @@ class HyperlinkedIdentityField(Field):
format = self.context.get('format', None) format = self.context.get('format', None)
view_name = self.view_name view_name = self.view_name
if request is None: assert request is not None, (
warnings.warn("Using `HyperlinkedIdentityField` without including the " "`HyperlinkedIdentityField` requires the request in the serializer"
"request in the serializer context is deprecated. " " context. Add `context={'request': request}` when instantiating "
"Add `context={'request': request}` when instantiating the serializer.", "the serializer."
DeprecationWarning, stacklevel=4) )
# By default use whatever format is given for the current context # By default use whatever format is given for the current context
# unless the target is a different type to the source. # unless the target is a different type to the source.
...@@ -606,41 +595,3 @@ class HyperlinkedIdentityField(Field): ...@@ -606,41 +595,3 @@ class HyperlinkedIdentityField(Field):
pass pass
raise NoReverseMatch() raise NoReverseMatch()
### Old-style many classes for backwards compat
class ManyRelatedField(RelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyRelatedField()` is deprecated. '
'Use `RelatedField(many=True)` instead.',
DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyRelatedField, self).__init__(*args, **kwargs)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
'Use `PrimaryKeyRelatedField(many=True)` instead.',
DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
class ManySlugRelatedField(SlugRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManySlugRelatedField()` is deprecated. '
'Use `SlugRelatedField(many=True)` instead.',
DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManySlugRelatedField, self).__init__(*args, **kwargs)
class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
'Use `HyperlinkedRelatedField(many=True)` instead.',
DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
...@@ -17,15 +17,17 @@ from __future__ import unicode_literals ...@@ -17,15 +17,17 @@ from __future__ import unicode_literals
import itertools import itertools
from collections import namedtuple from collections import namedtuple
from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import views from rest_framework import views
from rest_framework.compat import patterns, url
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
DynamicDetailRoute = namedtuple('DynamicDetailRoute', ['url', 'name', 'initkwargs'])
DynamicListRoute = namedtuple('DynamicListRoute', ['url', 'name', 'initkwargs'])
def replace_methodname(format_string, methodname): def replace_methodname(format_string, methodname):
...@@ -88,6 +90,14 @@ class SimpleRouter(BaseRouter): ...@@ -88,6 +90,14 @@ class SimpleRouter(BaseRouter):
name='{basename}-list', name='{basename}-list',
initkwargs={'suffix': 'List'} initkwargs={'suffix': 'List'}
), ),
# Dynamically generated list routes.
# Generated using @list_route decorator
# on methods of the viewset.
DynamicListRoute(
url=r'^{prefix}/{methodname}{trailing_slash}$',
name='{basename}-{methodnamehyphen}',
initkwargs={}
),
# Detail route. # Detail route.
Route( Route(
url=r'^{prefix}/{lookup}{trailing_slash}$', url=r'^{prefix}/{lookup}{trailing_slash}$',
...@@ -100,13 +110,10 @@ class SimpleRouter(BaseRouter): ...@@ -100,13 +110,10 @@ class SimpleRouter(BaseRouter):
name='{basename}-detail', name='{basename}-detail',
initkwargs={'suffix': 'Instance'} initkwargs={'suffix': 'Instance'}
), ),
# Dynamically generated routes. # Dynamically generated detail routes.
# Generated using @action or @link decorators on methods of the viewset. # Generated using @detail_route decorator on methods of the viewset.
Route( DynamicDetailRoute(
url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$', url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
name='{basename}-{methodnamehyphen}', name='{basename}-{methodnamehyphen}',
initkwargs={} initkwargs={}
), ),
...@@ -139,25 +146,42 @@ class SimpleRouter(BaseRouter): ...@@ -139,25 +146,42 @@ class SimpleRouter(BaseRouter):
Returns a list of the Route namedtuple. Returns a list of the Route namedtuple.
""" """
known_actions = flatten([route.mapping.values() for route in self.routes]) known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])
# Determine any `@action` or `@link` decorated methods on the viewset # Determine any `@detail_route` or `@list_route` decorated methods on the viewset
dynamic_routes = [] detail_routes = []
list_routes = []
for methodname in dir(viewset): for methodname in dir(viewset):
attr = getattr(viewset, methodname) attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None) httpmethods = getattr(attr, 'bind_to_methods', None)
detail = getattr(attr, 'detail', True)
if httpmethods: if httpmethods:
if methodname in known_actions: if methodname in known_actions:
raise ImproperlyConfigured('Cannot use @action or @link decorator on ' raise ImproperlyConfigured('Cannot use @detail_route or @list_route '
'method "%s" as it is an existing route' % methodname) 'decorators on method "%s" '
'as it is an existing route' % methodname)
httpmethods = [method.lower() for method in httpmethods] httpmethods = [method.lower() for method in httpmethods]
dynamic_routes.append((httpmethods, methodname)) if detail:
detail_routes.append((httpmethods, methodname))
else:
list_routes.append((httpmethods, methodname))
ret = [] ret = []
for route in self.routes: for route in self.routes:
if route.mapping == {'{httpmethod}': '{methodname}'}: if isinstance(route, DynamicDetailRoute):
# Dynamic routes (@link or @action decorator) # Dynamic detail routes (@detail_route decorator)
for httpmethods, methodname in dynamic_routes: for httpmethods, methodname in detail_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
url=replace_methodname(route.url, methodname),
mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
initkwargs=initkwargs,
))
elif isinstance(route, DynamicListRoute):
# Dynamic list routes (@list_route decorator)
for httpmethods, methodname in list_routes:
initkwargs = route.initkwargs.copy() initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs) initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route( ret.append(Route(
...@@ -195,13 +219,16 @@ class SimpleRouter(BaseRouter): ...@@ -195,13 +219,16 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers https://github.com/alanjds/drf-nested-routers
""" """
if self.trailing_slash: base_regex = '(?P<{lookup_prefix}{lookup_field}>{lookup_value})'
base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/]+)' # Use `pk` as default field, unset set. Default regex should not
else: # consume `.json` style suffixes and should break at '/' boundaries.
# Don't consume `.json` style suffixes
base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk') lookup_field = getattr(viewset, 'lookup_field', 'pk')
return base_regex.format(lookup_field=lookup_field, lookup_prefix=lookup_prefix) lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_field=lookup_field,
lookup_value=lookup_value
)
def get_urls(self): def get_urls(self):
""" """
......
#!/usr/bin/env python
"""
Useful tool to run the test suite for rest_framework and generate a coverage report.
"""
# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
from coverage import coverage
def main():
"""Run the tests for rest_framework and generate a coverage report."""
cov = coverage()
cov.erase()
cov.start()
from django.conf import settings
from django.test.utils import get_runner
TestRunner = get_runner(settings)
if hasattr(TestRunner, 'func_name'):
# Pre 1.2 test runners were just functions,
# and did not support the 'failfast' option.
import warnings
warnings.warn(
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['tests'])
else:
test_runner = TestRunner()
failures = test_runner.run_tests(['tests'])
cov.stop()
# Discover the list of all modules that we should test coverage for
import rest_framework
project_dir = os.path.dirname(rest_framework.__file__)
cov_files = []
for (path, dirs, files) in os.walk(project_dir):
# Drop tests and runtests directories from the test coverage report
if os.path.basename(path) in ['tests', 'runtests', 'migrations']:
continue
# Drop the compat and six modules from coverage, since we're not interested in the coverage
# of modules which are specifically for resolving environment dependant imports.
# (Because we'll end up getting different coverage reports for it for each environment)
if 'compat.py' in files:
files.remove('compat.py')
if 'six.py' in files:
files.remove('six.py')
# Same applies to template tags module.
# This module has to include branching on Django versions,
# so it's never possible for it to have full coverage.
if 'rest_framework.py' in files:
files.remove('rest_framework.py')
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
cov.report(cov_files)
if '--html' in sys.argv:
cov.html_report(cov_files, directory='coverage')
sys.exit(failures)
if __name__ == '__main__':
main()
#!/usr/bin/env python
# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
# fix sys path so we don't need to setup PYTHONPATH
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
import django
from django.conf import settings
from django.test.utils import get_runner
def usage():
return """
Usage: python runtests.py [UnitTestClass].[method]
You can pass the Class name of the `UnitTestClass` you want to test.
Append a method name if you only want to test a specific method of that class.
"""
def main():
try:
django.setup()
except AttributeError:
pass
TestRunner = get_runner(settings)
test_runner = TestRunner()
if len(sys.argv) == 2:
test_case = '.' + sys.argv[1]
elif len(sys.argv) == 1:
test_case = ''
else:
print(usage())
sys.exit(1)
test_module_name = 'rest_framework.tests'
if django.VERSION[0] == 1 and django.VERSION[1] < 6:
test_module_name = 'tests'
failures = test_runner.run_tests([test_module_name + test_case])
sys.exit(failures)
if __name__ == '__main__':
main()
"""
Blank URLConf just to keep runtests.py happy.
"""
from rest_framework.compat import patterns
urlpatterns = patterns('',
)
...@@ -21,7 +21,8 @@ from django.core.paginator import Page ...@@ -21,7 +21,8 @@ from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets from django.forms import widgets
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model, six from django.core.exceptions import ObjectDoesNotExist
from rest_framework.compat import six
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
...@@ -32,8 +33,8 @@ from rest_framework.settings import api_settings ...@@ -32,8 +33,8 @@ from rest_framework.settings import api_settings
# This helps keep the separation between model fields, form fields, and # This helps keep the separation between model fields, form fields, and
# serializer fields more explicit. # serializer fields more explicit.
from rest_framework.relations import * from rest_framework.relations import * # NOQA
from rest_framework.fields import * from rest_framework.fields import * # NOQA
def _resolve_model(obj): def _resolve_model(obj):
...@@ -48,7 +49,7 @@ def _resolve_model(obj): ...@@ -48,7 +49,7 @@ def _resolve_model(obj):
String representations should have the format: String representations should have the format:
'appname.ModelName' 'appname.ModelName'
""" """
if type(obj) == str and len(obj.split('.')) == 2: if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.') app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name) return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model): elif inspect.isclass(obj) and issubclass(obj, models.Model):
...@@ -181,7 +182,7 @@ class BaseSerializer(WritableField): ...@@ -181,7 +182,7 @@ class BaseSerializer(WritableField):
_dict_class = SortedDictWithMetadata _dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None, def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=None, context=None, partial=False, many=False,
allow_add_remove=False, **kwargs): allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
...@@ -344,7 +345,7 @@ class BaseSerializer(WritableField): ...@@ -344,7 +345,7 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if field.read_only and obj is None: if field.read_only and obj is None:
continue continue
field.initialize(parent=self, field_name=field_name) field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name) key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name) value = field.field_to_native(obj, field_name)
...@@ -411,12 +412,7 @@ class BaseSerializer(WritableField): ...@@ -411,12 +412,7 @@ class BaseSerializer(WritableField):
if value is None: if value is None:
return None return None
if self.many is not None: if self.many:
many = self.many
else:
many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
if many:
return [self.to_native(item) for item in value] return [self.to_native(item) for item in value]
return self.to_native(value) return self.to_native(value)
...@@ -662,7 +658,7 @@ class ModelSerializer(Serializer): ...@@ -662,7 +658,7 @@ class ModelSerializer(Serializer):
cls = self.opts.model cls = self.opts.model
assert cls is not None, \ assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta opts = cls._meta.concrete_model._meta
ret = SortedDict() ret = SortedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
...@@ -695,10 +691,10 @@ class ModelSerializer(Serializer): ...@@ -695,10 +691,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 2: if len(inspect.getargspec(self.get_nested_field).args) == 2:
warnings.warn( warnings.warn(
'The `get_nested_field(model_field)` call signature ' 'The `get_nested_field(model_field)` call signature '
'is due to be deprecated. ' 'is deprecated. '
'Use `get_nested_field(model_field, related_model, ' 'Use `get_nested_field(model_field, related_model, '
'to_many) instead', 'to_many) instead',
PendingDeprecationWarning DeprecationWarning
) )
field = self.get_nested_field(model_field) field = self.get_nested_field(model_field)
else: else:
...@@ -707,10 +703,10 @@ class ModelSerializer(Serializer): ...@@ -707,10 +703,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 3: if len(inspect.getargspec(self.get_nested_field).args) == 3:
warnings.warn( warnings.warn(
'The `get_related_field(model_field, to_many)` call ' 'The `get_related_field(model_field, to_many)` call '
'signature is due to be deprecated. ' 'signature is deprecated. '
'Use `get_related_field(model_field, related_model, ' 'Use `get_related_field(model_field, related_model, '
'to_many) instead', 'to_many) instead',
PendingDeprecationWarning DeprecationWarning
) )
field = self.get_related_field(model_field, to_many=to_many) field = self.get_related_field(model_field, to_many=to_many)
else: else:
...@@ -758,9 +754,9 @@ class ModelSerializer(Serializer): ...@@ -758,9 +754,9 @@ class ModelSerializer(Serializer):
field.read_only = True field.read_only = True
ret[accessor_name] = field ret[accessor_name] = field
# Ensure that 'read_only_fields' is an iterable # Ensure that 'read_only_fields' is an iterable
assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified # Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option # in the `read_only_fields` option
...@@ -775,10 +771,10 @@ class ModelSerializer(Serializer): ...@@ -775,10 +771,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." % "on serializer '%s'." %
(field_name, self.__class__.__name__)) (field_name, self.__class__.__name__))
ret[field_name].read_only = True ret[field_name].read_only = True
# Ensure that 'write_only_fields' is an iterable # Ensure that 'write_only_fields' is an iterable
assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
for field_name in self.opts.write_only_fields: for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), ( assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in " "field '%s' on serializer '%s' specified in "
...@@ -789,7 +785,7 @@ class ModelSerializer(Serializer): ...@@ -789,7 +785,7 @@ class ModelSerializer(Serializer):
"Non-existant field '%s' specified in `write_only_fields` " "Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." % "on serializer '%s'." %
(field_name, self.__class__.__name__)) (field_name, self.__class__.__name__))
ret[field_name].write_only = True ret[field_name].write_only = True
return ret return ret
...@@ -880,6 +876,10 @@ class ModelSerializer(Serializer): ...@@ -880,6 +876,10 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, models.PositiveSmallIntegerField): issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0 kwargs['min_value'] = 0
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
attribute_dict = { attribute_dict = {
models.CharField: ['max_length'], models.CharField: ['max_length'],
models.CommaSeparatedIntegerField: ['max_length'], models.CommaSeparatedIntegerField: ['max_length'],
...@@ -906,7 +906,7 @@ class ModelSerializer(Serializer): ...@@ -906,7 +906,7 @@ class ModelSerializer(Serializer):
Return a list of field names to exclude from model validation. Return a list of field names to exclude from model validation.
""" """
cls = self.opts.model cls = self.opts.model
opts = get_concrete_model(cls)._meta opts = cls._meta.concrete_model._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many] exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
......
...@@ -63,6 +63,7 @@ DEFAULTS = { ...@@ -63,6 +63,7 @@ DEFAULTS = {
'user': None, 'user': None,
'anon': None, 'anon': None,
}, },
'NUM_PROXIES': None,
# Pagination # Pagination
'PAGINATE_BY': None, 'PAGINATE_BY': None,
......
{% load url from future %} {% load url from future %}
{% load staticfiles %}
{% load rest_framework %} {% load rest_framework %}
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
......
{% load url from future %} {% load url from future %}
{% load staticfiles %}
{% load rest_framework %} {% load rest_framework %}
<html> <html>
......
...@@ -5,95 +5,13 @@ from django.http import QueryDict ...@@ -5,95 +5,13 @@ from django.http import QueryDict
from django.utils.encoding import iri_to_uri from django.utils.encoding import iri_to_uri
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe from django.utils.safestring import SafeData, mark_safe
from rest_framework.compat import urlparse, force_text, six, smart_urlquote from rest_framework.compat import urlparse, force_text, six
from django.utils.html import smart_urlquote
import re import re
register = template.Library() register = template.Library()
# Note we don't use 'load staticfiles', because we need a 1.3 compatible
# version, so instead we include the `static` template tag ourselves.
# When 1.3 becomes unsupported by REST framework, we can instead start to
# use the {% load staticfiles %} tag, remove the following code,
# and add a dependency that `django.contrib.staticfiles` must be installed.
# Note: We can't put this into the `compat` module because the compat import
# from rest_framework.compat import ...
# conflicts with this rest_framework template tag module.
try: # Django 1.5+
from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode
@register.tag('static')
def do_static(parser, token):
return StaticFilesNode.handle_token(parser, token)
except ImportError:
try: # Django 1.4
from django.contrib.staticfiles.storage import staticfiles_storage
@register.simple_tag
def static(path):
"""
A template tag that returns the URL to a file
using staticfiles' storage backend
"""
return staticfiles_storage.url(path)
except ImportError: # Django 1.3
from urlparse import urljoin
from django import template
from django.templatetags.static import PrefixNode
class StaticNode(template.Node):
def __init__(self, varname=None, path=None):
if path is None:
raise template.TemplateSyntaxError(
"Static template nodes must be given a path to return.")
self.path = path
self.varname = varname
def url(self, context):
path = self.path.resolve(context)
return self.handle_simple(path)
def render(self, context):
url = self.url(context)
if self.varname is None:
return url
context[self.varname] = url
return ''
@classmethod
def handle_simple(cls, path):
return urljoin(PrefixNode.handle_simple("STATIC_URL"), path)
@classmethod
def handle_token(cls, parser, token):
"""
Class method to parse prefix node and return a Node.
"""
bits = token.split_contents()
if len(bits) < 2:
raise template.TemplateSyntaxError(
"'%s' takes at least one argument (path to file)" % bits[0])
path = parser.compile_filter(bits[1])
if len(bits) >= 2 and bits[-2] == 'as':
varname = bits[3]
else:
varname = None
return cls(varname, path)
@register.tag('static')
def do_static_13(parser, token):
return StaticNode.handle_token(parser, token)
def replace_query_param(url, key, val): def replace_query_param(url, key, val):
""" """
Given a URL and a key/val pair, set or replace an item in the query Given a URL and a key/val pair, set or replace an item in the query
...@@ -122,7 +40,7 @@ def optional_login(request): ...@@ -122,7 +40,7 @@ def optional_login(request):
except NoReverseMatch: except NoReverseMatch:
return '' return ''
snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path) snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, escape(request.path))
return snippet return snippet
...@@ -136,7 +54,7 @@ def optional_logout(request): ...@@ -136,7 +54,7 @@ def optional_logout(request):
except NoReverseMatch: except NoReverseMatch:
return '' return ''
snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path) snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, escape(request.path))
return snippet return snippet
......
...@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory): ...@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):
""" """
if not data: if not data:
return ('', None) return ('', content_type)
assert format is None or content_type is None, ( assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.' 'You may not set both `format` and `content_type`.'
......
"""
Force import of all modules in this package in order to get the standard test
runner to pick up the tests. Yowzers.
"""
from __future__ import unicode_literals
import os
import django
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
if django.VERSION < (1, 6):
for module in modules:
exec("from rest_framework.tests.%s import *" % module)
...@@ -18,6 +18,25 @@ class BaseThrottle(object): ...@@ -18,6 +18,25 @@ class BaseThrottle(object):
""" """
raise NotImplementedError('.allow_request() must be overridden') raise NotImplementedError('.allow_request() must be overridden')
def get_ident(self, request):
"""
Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(xff))]
return client_addr.strip()
return xff if xff else remote_addr
def wait(self): def wait(self):
""" """
Optionally, return a recommended number of seconds to wait before Optionally, return a recommended number of seconds to wait before
...@@ -162,7 +181,7 @@ class AnonRateThrottle(SimpleRateThrottle): ...@@ -162,7 +181,7 @@ class AnonRateThrottle(SimpleRateThrottle):
return self.cache_format % { return self.cache_format % {
'scope': self.scope, 'scope': self.scope,
'ident': ident 'ident': self.get_ident(request)
} }
...@@ -180,7 +199,7 @@ class UserRateThrottle(SimpleRateThrottle): ...@@ -180,7 +199,7 @@ class UserRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated(): if request.user.is_authenticated():
ident = request.user.id ident = request.user.id
else: else:
ident = request.META.get('REMOTE_ADDR', None) ident = self.get_ident(request)
return self.cache_format % { return self.cache_format % {
'scope': self.scope, 'scope': self.scope,
...@@ -228,7 +247,7 @@ class ScopedRateThrottle(SimpleRateThrottle): ...@@ -228,7 +247,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated(): if request.user.is_authenticated():
ident = request.user.id ident = request.user.id
else: else:
ident = request.META.get('REMOTE_ADDR', None) ident = self.get_ident(request)
return self.cache_format % { return self.cache_format % {
'scope': self.scope, 'scope': self.scope,
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import url, include
from django.core.urlresolvers import RegexURLResolver from django.core.urlresolvers import RegexURLResolver
from rest_framework.compat import url, include
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
......
...@@ -13,7 +13,7 @@ your authentication settings include `SessionAuthentication`. ...@@ -13,7 +13,7 @@ your authentication settings include `SessionAuthentication`.
) )
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from rest_framework.compat import patterns, url from django.conf.urls import patterns, url
template_name = {'template_name': 'rest_framework/login.html'} template_name = {'template_name': 'rest_framework/login.html'}
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
Helper classes for parsers. Helper classes for parsers.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.utils import timezone
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.functional import Promise from django.utils.functional import Promise
from rest_framework.compat import timezone, force_text from rest_framework.compat import force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime import datetime
import decimal import decimal
......
...@@ -2,11 +2,26 @@ ...@@ -2,11 +2,26 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from setuptools import setup from setuptools import setup
from setuptools.command.test import test as TestCommand
import re import re
import os import os
import sys import sys
# This command has been borrowed from
# https://github.com/getsentry/sentry/blob/master/setup.py
class PyTest(TestCommand):
def finalize_options(self):
TestCommand.finalize_options(self)
self.test_args = ['tests']
self.test_suite = True
def run_tests(self):
import pytest
errno = pytest.main(self.test_args)
sys.exit(errno)
def get_version(package): def get_version(package):
""" """
Return package version as listed in `__version__` in `init.py`. Return package version as listed in `__version__` in `init.py`.
...@@ -62,7 +77,7 @@ setup( ...@@ -62,7 +77,7 @@ setup(
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*) author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=get_packages('rest_framework'), packages=get_packages('rest_framework'),
package_data=get_package_data('rest_framework'), package_data=get_package_data('rest_framework'),
test_suite='rest_framework.runtests.runtests.main', cmdclass={'test': PyTest},
install_requires=[], install_requires=[],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
......
from django.db import models from django.db import models
from rest_framework.tests.users.models import User from tests.users.models import User
class Account(models.Model): class Account(models.Model):
......
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.accounts.models import Account from tests.accounts.models import Account
from rest_framework.tests.users.serializers import UserSerializer from tests.users.serializers import UserSerializer
class AccountSerializer(serializers.ModelSerializer): class AccountSerializer(serializers.ModelSerializer):
......
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import NullableForeignKeySource from tests.models import NullableForeignKeySource
class NullableFKSourceSerializer(serializers.ModelSerializer): class NullableFKSourceSerializer(serializers.ModelSerializer):
......
...@@ -79,7 +79,7 @@ MIDDLEWARE_CLASSES = ( ...@@ -79,7 +79,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
) )
ROOT_URLCONF = 'urls' ROOT_URLCONF = 'tests.urls'
TEMPLATE_DIRS = ( TEMPLATE_DIRS = (
# Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
...@@ -93,16 +93,13 @@ INSTALLED_APPS = ( ...@@ -93,16 +93,13 @@ INSTALLED_APPS = (
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.sites', 'django.contrib.sites',
'django.contrib.messages', 'django.contrib.messages',
# Uncomment the next line to enable the admin: 'django.contrib.staticfiles',
# 'django.contrib.admin',
# Uncomment the next line to enable admin documentation:
# 'django.contrib.admindocs',
'rest_framework', 'rest_framework',
'rest_framework.authtoken', 'rest_framework.authtoken',
'rest_framework.tests', 'tests',
'rest_framework.tests.accounts', 'tests.accounts',
'rest_framework.tests.records', 'tests.records',
'rest_framework.tests.users', 'tests.users',
) )
# OAuth is optional and won't work if there is no oauth_provider & oauth2 # OAuth is optional and won't work if there is no oauth_provider & oauth2
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url, include
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
...@@ -19,7 +20,7 @@ from rest_framework.authentication import ( ...@@ -19,7 +20,7 @@ from rest_framework.authentication import (
OAuth2Authentication OAuth2Authentication
) )
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include, six from rest_framework.compat import six
from rest_framework.compat import oauth2_provider, oauth2_provider_scope from rest_framework.compat import oauth2_provider, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient from rest_framework.test import APIRequestFactory, APIClient
...@@ -69,7 +70,7 @@ if oauth2_provider is not None: ...@@ -69,7 +70,7 @@ if oauth2_provider is not None:
class BasicAuthTests(TestCase): class BasicAuthTests(TestCase):
"""Basic authentication""" """Basic authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
...@@ -108,7 +109,7 @@ class BasicAuthTests(TestCase): ...@@ -108,7 +109,7 @@ class BasicAuthTests(TestCase):
class SessionAuthTests(TestCase): class SessionAuthTests(TestCase):
"""User session authentication""" """User session authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
...@@ -155,7 +156,7 @@ class SessionAuthTests(TestCase): ...@@ -155,7 +156,7 @@ class SessionAuthTests(TestCase):
class TokenAuthTests(TestCase): class TokenAuthTests(TestCase):
"""Token authentication""" """Token authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
...@@ -255,7 +256,7 @@ class IncorrectCredentialsTests(TestCase): ...@@ -255,7 +256,7 @@ class IncorrectCredentialsTests(TestCase):
class OAuthTests(TestCase): class OAuthTests(TestCase):
"""OAuth 1.0a authentication""" """OAuth 1.0a authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'tests.test_authentication'
def setUp(self): def setUp(self):
# these imports are here because oauth is optional and hiding them in try..except block or compat # these imports are here because oauth is optional and hiding them in try..except block or compat
...@@ -485,7 +486,7 @@ class OAuthTests(TestCase): ...@@ -485,7 +486,7 @@ class OAuthTests(TestCase):
class OAuth2Tests(TestCase): class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication""" """OAuth 2.0 authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.views import APIView from rest_framework.views import APIView
...@@ -36,7 +36,7 @@ urlpatterns = patterns('', ...@@ -36,7 +36,7 @@ urlpatterns = patterns('',
class BreadcrumbTests(TestCase): class BreadcrumbTests(TestCase):
"""Tests the breadcrumb functionality used by the HTML renderer.""" """Tests the breadcrumb functionality used by the HTML renderer."""
urls = 'rest_framework.tests.test_breadcrumbs' urls = 'tests.test_breadcrumbs'
def test_root_breadcrumbs(self): def test_root_breadcrumbs(self):
url = '/' url = '/'
......
...@@ -4,8 +4,8 @@ from __future__ import unicode_literals ...@@ -4,8 +4,8 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import apply_markdown, smart_text from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring from .description import ViewWithNonASCIICharactersInDocstring
from rest_framework.tests.description import UTF8_TEST_DOCSTRING from .description import UTF8_TEST_DOCSTRING
# We check that docstrings get nicely un-indented. # We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring DESCRIPTION = """an example docstring
......
...@@ -4,6 +4,7 @@ General serializer field tests. ...@@ -4,6 +4,7 @@ General serializer field tests.
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
import re
from decimal import Decimal from decimal import Decimal
from uuid import uuid4 from uuid import uuid4
from django.core import validators from django.core import validators
...@@ -11,7 +12,7 @@ from django.db import models ...@@ -11,7 +12,7 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import RESTFrameworkModel from tests.models import RESTFrameworkModel
class TimestampedModel(models.Model): class TimestampedModel(models.Model):
...@@ -103,6 +104,16 @@ class BasicFieldTests(TestCase): ...@@ -103,6 +104,16 @@ class BasicFieldTests(TestCase):
keys = list(field.to_native(ret).keys()) keys = list(field.to_native(ret).keys())
self.assertEqual(keys, ['c', 'b', 'a', 'z']) self.assertEqual(keys, ['c', 'b', 'a', 'z'])
def test_widget_html_attributes(self):
"""
Make sure widget_html() renders the correct attributes
"""
r = re.compile('(\S+)=["\']?((?:.(?!["\']?\s+(?:\S+)=|[>"\']))+.)["\']?')
form = TimeFieldModelSerializer().data
attributes = r.findall(form.fields['clock'].widget_html())
self.assertIn(('name', 'clock'), attributes)
self.assertIn(('id', 'clock'), attributes)
class DateFieldTest(TestCase): class DateFieldTest(TestCase):
""" """
...@@ -706,6 +717,15 @@ class ChoiceFieldTests(TestCase): ...@@ -706,6 +717,15 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
def test_blank_choice_display(self):
blank = 'No Preference'
f = serializers.ChoiceField(
required=False,
choices=SAMPLE_CHOICES,
blank_display_value=blank,
)
self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES)
def test_invalid_choice_model(self): def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
......
...@@ -5,12 +5,11 @@ from django.db import models ...@@ -5,12 +5,11 @@ from django.db import models
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.conf.urls import patterns, url
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from .models import FilterableItem, BasicModel
from .models import FilterableItem
from .utils import temporary_setting from .utils import temporary_setting
factory = APIRequestFactory() factory = APIRequestFactory()
...@@ -243,7 +242,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): ...@@ -243,7 +242,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
""" """
Integration tests for filtered detail views. Integration tests for filtered detail views.
""" """
urls = 'rest_framework.tests.test_filters' urls = 'tests.test_filters'
def _get_url(self, item): def _get_url(self, item):
return reverse('detail-view', kwargs=dict(pk=item.pk)) return reverse('detail-view', kwargs=dict(pk=item.pk))
......
...@@ -84,7 +84,7 @@ class TestGenericRelations(TestCase): ...@@ -84,7 +84,7 @@ class TestGenericRelations(TestCase):
exclude = ('content_type', 'object_id') exclude = ('content_type', 'object_id')
class BookmarkSerializer(serializers.ModelSerializer): class BookmarkSerializer(serializers.ModelSerializer):
tags = TagSerializer() tags = TagSerializer(many=True)
class Meta: class Meta:
model = Bookmark model = Bookmark
......
...@@ -4,8 +4,8 @@ from django.shortcuts import get_object_or_404 ...@@ -4,8 +4,8 @@ from django.shortcuts import get_object_or_404
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, renderers, serializers, status from rest_framework import generics, renderers, serializers, status
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.tests.models import ForeignKeySource, ForeignKeyTarget from tests.models import ForeignKeySource, ForeignKeyTarget
from rest_framework.compat import six from rest_framework.compat import six
factory = APIRequestFactory() factory = APIRequestFactory()
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.conf.urls import patterns, url
from django.http import Http404 from django.http import Http404
from django.test import TestCase from django.test import TestCase
from django.template import TemplateDoesNotExist, Template from django.template import TemplateDoesNotExist, Template
import django.template.loader import django.template.loader
from rest_framework import status from rest_framework import status
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes from rest_framework.decorators import api_view, renderer_classes
from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response from rest_framework.response import Response
...@@ -42,7 +42,7 @@ urlpatterns = patterns('', ...@@ -42,7 +42,7 @@ urlpatterns = patterns('',
class TemplateHTMLRendererTests(TestCase): class TemplateHTMLRendererTests(TestCase):
urls = 'rest_framework.tests.test_htmlrenderer' urls = 'tests.test_htmlrenderer'
def setUp(self): def setUp(self):
""" """
...@@ -82,7 +82,7 @@ class TemplateHTMLRendererTests(TestCase): ...@@ -82,7 +82,7 @@ class TemplateHTMLRendererTests(TestCase):
class TemplateHTMLRendererExceptionTests(TestCase): class TemplateHTMLRendererExceptionTests(TestCase):
urls = 'rest_framework.tests.test_htmlrenderer' urls = 'tests.test_htmlrenderer'
def setUp(self): def setUp(self):
""" """
......
...@@ -2,10 +2,10 @@ from __future__ import unicode_literals ...@@ -2,10 +2,10 @@ from __future__ import unicode_literals
import json import json
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.compat import patterns, url from django.conf.urls import patterns, url
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import ( from tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
Album, Photo, OptionalRelationModel Album, Photo, OptionalRelationModel
) )
...@@ -25,7 +25,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): ...@@ -25,7 +25,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer): class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title')
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
return Photo(**attrs) return Photo(**attrs)
...@@ -110,7 +110,7 @@ urlpatterns = patterns('', ...@@ -110,7 +110,7 @@ urlpatterns = patterns('',
class TestBasicHyperlinkedView(TestCase): class TestBasicHyperlinkedView(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -147,7 +147,7 @@ class TestBasicHyperlinkedView(TestCase): ...@@ -147,7 +147,7 @@ class TestBasicHyperlinkedView(TestCase):
class TestManyToManyHyperlinkedView(TestCase): class TestManyToManyHyperlinkedView(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -195,7 +195,7 @@ class TestManyToManyHyperlinkedView(TestCase): ...@@ -195,7 +195,7 @@ class TestManyToManyHyperlinkedView(TestCase):
class TestHyperlinkedIdentityFieldLookup(TestCase): class TestHyperlinkedIdentityFieldLookup(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -225,7 +225,7 @@ class TestHyperlinkedIdentityFieldLookup(TestCase): ...@@ -225,7 +225,7 @@ class TestHyperlinkedIdentityFieldLookup(TestCase):
class TestCreateWithForeignKeys(TestCase): class TestCreateWithForeignKeys(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -250,7 +250,7 @@ class TestCreateWithForeignKeys(TestCase): ...@@ -250,7 +250,7 @@ class TestCreateWithForeignKeys(TestCase):
class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -275,7 +275,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): ...@@ -275,7 +275,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
class TestOptionalRelationHyperlinkedView(TestCase): class TestOptionalRelationHyperlinkedView(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
""" """
...@@ -335,7 +335,7 @@ class TestOverriddenURLField(TestCase): ...@@ -335,7 +335,7 @@ class TestOverriddenURLField(TestCase):
class TestURLFieldNameBySettings(TestCase): class TestURLFieldNameBySettings(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
self.saved_url_field_name = api_settings.URL_FIELD_NAME self.saved_url_field_name = api_settings.URL_FIELD_NAME
...@@ -360,7 +360,7 @@ class TestURLFieldNameBySettings(TestCase): ...@@ -360,7 +360,7 @@ class TestURLFieldNameBySettings(TestCase):
class TestURLFieldNameByOptions(TestCase): class TestURLFieldNameByOptions(TestCase):
urls = 'rest_framework.tests.test_hyperlinkedserializers' urls = 'tests.test_hyperlinkedserializers'
def setUp(self): def setUp(self):
class Serializer(serializers.HyperlinkedModelSerializer): class Serializer(serializers.HyperlinkedModelSerializer):
......
...@@ -2,7 +2,7 @@ from __future__ import unicode_literals ...@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import RESTFrameworkModel from tests.models import RESTFrameworkModel
# Models # Models
......
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from rest_framework.compat import patterns, url from django.conf.urls import patterns, url
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from rest_framework.tests.models import NullableForeignKeySource from tests.models import NullableForeignKeySource
from rest_framework.tests.serializers import NullableFKSourceSerializer from tests.serializers import NullableFKSourceSerializer
from rest_framework.tests.views import NullableFKSourceDetail from tests.views import NullableFKSourceDetail
urlpatterns = patterns( urlpatterns = patterns(
...@@ -18,7 +18,7 @@ class NullableForeignKeyTests(APITestCase): ...@@ -18,7 +18,7 @@ class NullableForeignKeyTests(APITestCase):
DRF should be able to handle nullable foreign keys when a test DRF should be able to handle nullable foreign keys when a test
Client POST/PUT request is made with its own serialized object. Client POST/PUT request is made with its own serialized object.
""" """
urls = 'rest_framework.tests.test_nullable_fields' urls = 'tests.test_nullable_fields'
def test_updating_object_with_null_fk(self): def test_updating_object_with_null_fk(self):
obj = NullableForeignKeySource(name='example', target=None) obj = NullableForeignKeySource(name='example', target=None)
......
...@@ -8,8 +8,7 @@ from django.utils import unittest ...@@ -8,8 +8,7 @@ from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from .models import BasicModel, FilterableItem
from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
......
...@@ -7,7 +7,7 @@ from rest_framework import generics, status, permissions, authentication, HTTP_H ...@@ -7,7 +7,7 @@ from rest_framework import generics, status, permissions, authentication, HTTP_H
from rest_framework.compat import guardian, get_model_name from rest_framework.compat import guardian, get_model_name
from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from tests.models import BasicModel
import base64 import base64
factory = APIRequestFactory() factory = APIRequestFactory()
...@@ -187,8 +187,7 @@ class ObjectPermissionsIntegrationTests(TestCase): ...@@ -187,8 +187,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
""" """
Integration tests for the object level permissions API. Integration tests for the object level permissions API.
""" """
@classmethod def setUp(self):
def setUpClass(cls):
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
# create users # create users
...@@ -215,21 +214,13 @@ class ObjectPermissionsIntegrationTests(TestCase): ...@@ -215,21 +214,13 @@ class ObjectPermissionsIntegrationTests(TestCase):
assign_perm(perm, everyone) assign_perm(perm, everyone)
everyone.user_set.add(*users.values()) everyone.user_set.add(*users.values())
cls.perms = perms
cls.users = users
def setUp(self):
from guardian.shortcuts import assign_perm
perms = self.perms
users = self.users
# appropriate object level permissions # appropriate object level permissions
readers = Group.objects.create(name='readers') readers = Group.objects.create(name='readers')
writers = Group.objects.create(name='writers') writers = Group.objects.create(name='writers')
deleters = Group.objects.create(name='deleters') deleters = Group.objects.create(name='deleters')
model = BasicPermModel.objects.create(text='foo') model = BasicPermModel.objects.create(text='foo')
assign_perm(perms['view'], readers, model) assign_perm(perms['view'], readers, model)
assign_perm(perms['change'], writers, model) assign_perm(perms['change'], writers, model)
assign_perm(perms['delete'], deleters, model) assign_perm(perms['delete'], deleters, model)
......
...@@ -7,7 +7,7 @@ from django.db import models ...@@ -7,7 +7,7 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import BlogPost from tests.models import BlogPost
class NullModel(models.Model): class NullModel(models.Model):
...@@ -107,7 +107,7 @@ class RelatedFieldSourceTests(TestCase): ...@@ -107,7 +107,7 @@ class RelatedFieldSourceTests(TestCase):
Check that the exception message are correct if the source field Check that the exception message are correct if the source field
doesn't exist. doesn't exist.
""" """
from rest_framework.tests.models import ManyToManySource from tests.models import ManyToManySource
class Meta: class Meta:
model = ManyToManySource model = ManyToManySource
attrs = { attrs = {
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import patterns, url
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import ( from tests.models import (
BlogPost, BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
...@@ -71,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): ...@@ -71,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class HyperlinkedManyToManyTests(TestCase): class HyperlinkedManyToManyTests(TestCase):
urls = 'rest_framework.tests.test_relations_hyperlink' urls = 'tests.test_relations_hyperlink'
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
...@@ -179,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase): ...@@ -179,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase):
class HyperlinkedForeignKeyTests(TestCase): class HyperlinkedForeignKeyTests(TestCase):
urls = 'rest_framework.tests.test_relations_hyperlink' urls = 'tests.test_relations_hyperlink'
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
...@@ -307,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase): ...@@ -307,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase):
class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase):
urls = 'rest_framework.tests.test_relations_hyperlink' urls = 'tests.test_relations_hyperlink'
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
...@@ -435,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): ...@@ -435,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
class HyperlinkedNullableOneToOneTests(TestCase): class HyperlinkedNullableOneToOneTests(TestCase):
urls = 'rest_framework.tests.test_relations_hyperlink' urls = 'tests.test_relations_hyperlink'
def setUp(self): def setUp(self):
target = OneToOneTarget(name='target-1') target = OneToOneTarget(name='target-1')
...@@ -458,7 +458,7 @@ class HyperlinkedNullableOneToOneTests(TestCase): ...@@ -458,7 +458,7 @@ class HyperlinkedNullableOneToOneTests(TestCase):
# Regression tests for #694 (`source` attribute on related fields) # Regression tests for #694 (`source` attribute on related fields)
class HyperlinkedRelatedFieldSourceTests(TestCase): class HyperlinkedRelatedFieldSourceTests(TestCase):
urls = 'rest_framework.tests.test_relations_hyperlink' urls = 'tests.test_relations_hyperlink'
def test_related_manager_source(self): def test_related_manager_source(self):
""" """
......
...@@ -2,7 +2,7 @@ from __future__ import unicode_literals ...@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import ( from tests.models import (
BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
) )
......
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
class ForeignKeyTargetSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from decimal import Decimal from decimal import Decimal
from django.conf.urls import patterns, url, include
from django.core.cache import cache from django.core.cache import cache
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions from rest_framework import status, permissions
from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO from rest_framework.compat import yaml, etree, six, StringIO
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
...@@ -152,7 +153,7 @@ class RendererEndToEndTests(TestCase): ...@@ -152,7 +153,7 @@ class RendererEndToEndTests(TestCase):
End-to-end testing of renderers using an RendererMixin on a generic view. End-to-end testing of renderers using an RendererMixin on a generic view.
""" """
urls = 'rest_framework.tests.test_renderers' urls = 'tests.test_renderers'
def test_default_renderer_serializes_content(self): def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
...@@ -387,7 +388,7 @@ class JSONPRendererTests(TestCase): ...@@ -387,7 +388,7 @@ class JSONPRendererTests(TestCase):
Tests specific to the JSONP Renderer Tests specific to the JSONP Renderer
""" """
urls = 'rest_framework.tests.test_renderers' urls = 'tests.test_renderers'
def test_without_callback_with_json_renderer(self): def test_without_callback_with_json_renderer(self):
""" """
...@@ -582,7 +583,7 @@ class CacheRenderTest(TestCase): ...@@ -582,7 +583,7 @@ class CacheRenderTest(TestCase):
Tests specific to caching responses Tests specific to caching responses
""" """
urls = 'rest_framework.tests.test_renderers' urls = 'tests.test_renderers'
cache_key = 'just_a_cache_key' cache_key = 'just_a_cache_key'
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Tests for content parsing, and form-overloaded content parsing. Tests for content parsing, and form-overloaded content parsing.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
...@@ -9,7 +10,6 @@ from django.core.handlers.wsgi import WSGIRequest ...@@ -9,7 +10,6 @@ from django.core.handlers.wsgi import WSGIRequest
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import patterns
from rest_framework.parsers import ( from rest_framework.parsers import (
BaseParser, BaseParser,
FormParser, FormParser,
...@@ -278,7 +278,7 @@ urlpatterns = patterns('', ...@@ -278,7 +278,7 @@ urlpatterns = patterns('',
class TestContentParsingWithAuthentication(TestCase): class TestContentParsingWithAuthentication(TestCase):
urls = 'rest_framework.tests.test_request' urls = 'tests.test_request'
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url, include
from django.test import TestCase from django.test import TestCase
from rest_framework.tests.models import BasicModel, BasicModelSerializer from tests.models import BasicModel, BasicModelSerializer
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework import generics from rest_framework import generics
...@@ -118,7 +118,7 @@ class RendererIntegrationTests(TestCase): ...@@ -118,7 +118,7 @@ class RendererIntegrationTests(TestCase):
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
""" """
urls = 'rest_framework.tests.test_response' urls = 'tests.test_response'
def test_default_renderer_serializes_content(self): def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
...@@ -198,7 +198,7 @@ class Issue122Tests(TestCase): ...@@ -198,7 +198,7 @@ class Issue122Tests(TestCase):
""" """
Tests that covers #122. Tests that covers #122.
""" """
urls = 'rest_framework.tests.test_response' urls = 'tests.test_response'
def test_only_html_renderer(self): def test_only_html_renderer(self):
""" """
...@@ -218,7 +218,7 @@ class Issue467Tests(TestCase): ...@@ -218,7 +218,7 @@ class Issue467Tests(TestCase):
Tests for #467 Tests for #467
""" """
urls = 'rest_framework.tests.test_response' urls = 'tests.test_response'
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
...@@ -232,7 +232,7 @@ class Issue807Tests(TestCase): ...@@ -232,7 +232,7 @@ class Issue807Tests(TestCase):
Covers #807 Covers #807
""" """
urls = 'rest_framework.tests.test_response' urls = 'tests.test_response'
def test_does_not_append_charset_by_default(self): def test_does_not_append_charset_by_default(self):
""" """
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
...@@ -19,7 +19,7 @@ class ReverseTests(TestCase): ...@@ -19,7 +19,7 @@ class ReverseTests(TestCase):
""" """
Tests for fully qualified URLs when using `reverse`. Tests for fully qualified URLs when using `reverse`.
""" """
urls = 'rest_framework.tests.test_reverse' urls = 'tests.test_reverse'
def test_reversed_urls_are_fully_qualified(self): def test_reversed_urls_are_fully_qualified(self):
request = factory.get('/view') request = factory.get('/view')
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url, include
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url from rest_framework.decorators import detail_route, list_route
from rest_framework.decorators import link, action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
...@@ -18,23 +18,23 @@ class BasicViewSet(viewsets.ViewSet): ...@@ -18,23 +18,23 @@ class BasicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
return Response({'method': 'list'}) return Response({'method': 'list'})
@action() @detail_route(methods=['post'])
def action1(self, request, *args, **kwargs): def action1(self, request, *args, **kwargs):
return Response({'method': 'action1'}) return Response({'method': 'action1'})
@action() @detail_route(methods=['post'])
def action2(self, request, *args, **kwargs): def action2(self, request, *args, **kwargs):
return Response({'method': 'action2'}) return Response({'method': 'action2'})
@action(methods=['post', 'delete']) @detail_route(methods=['post', 'delete'])
def action3(self, request, *args, **kwargs): def action3(self, request, *args, **kwargs):
return Response({'method': 'action2'}) return Response({'method': 'action2'})
@link() @detail_route()
def link1(self, request, *args, **kwargs): def link1(self, request, *args, **kwargs):
return Response({'method': 'link1'}) return Response({'method': 'link1'})
@link() @detail_route()
def link2(self, request, *args, **kwargs): def link2(self, request, *args, **kwargs):
return Response({'method': 'link2'}) return Response({'method': 'link2'})
...@@ -72,7 +72,7 @@ class TestCustomLookupFields(TestCase): ...@@ -72,7 +72,7 @@ class TestCustomLookupFields(TestCase):
""" """
Ensure that custom lookup fields are correctly routed. Ensure that custom lookup fields are correctly routed.
""" """
urls = 'rest_framework.tests.test_routers' urls = 'tests.test_routers'
def setUp(self): def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer): class NoteSerializer(serializers.HyperlinkedModelSerializer):
...@@ -91,7 +91,7 @@ class TestCustomLookupFields(TestCase): ...@@ -91,7 +91,7 @@ class TestCustomLookupFields(TestCase):
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes', NoteViewSet)
from rest_framework.tests import test_routers from tests import test_routers
urls = getattr(test_routers, 'urlpatterns') urls = getattr(test_routers, 'urlpatterns')
urls += patterns('', urls += patterns('',
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
...@@ -121,6 +121,27 @@ class TestCustomLookupFields(TestCase): ...@@ -121,6 +121,27 @@ class TestCustomLookupFields(TestCase):
) )
class TestLookupValueRegex(TestCase):
"""
Ensure the router honors lookup_value_regex when applied
to the viewset.
"""
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f]{32}'
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
def test_urls_limited_by_lookup_value_regex(self):
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
class TestTrailingSlashIncluded(TestCase): class TestTrailingSlashIncluded(TestCase):
def setUp(self): def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
...@@ -131,7 +152,7 @@ class TestTrailingSlashIncluded(TestCase): ...@@ -131,7 +152,7 @@ class TestTrailingSlashIncluded(TestCase):
self.urls = self.router.urls self.urls = self.router.urls
def test_urls_have_trailing_slash_by_default(self): def test_urls_have_trailing_slash_by_default(self):
expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
for idx in range(len(expected)): for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern) self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
...@@ -175,7 +196,7 @@ class TestActionKeywordArgs(TestCase): ...@@ -175,7 +196,7 @@ class TestActionKeywordArgs(TestCase):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
permission_classes = [] permission_classes = []
@action(permission_classes=[permissions.AllowAny]) @detail_route(methods=['post'], permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs): def custom(self, request, *args, **kwargs):
return Response({ return Response({
'permission_classes': self.permission_classes 'permission_classes': self.permission_classes
...@@ -196,14 +217,14 @@ class TestActionKeywordArgs(TestCase): ...@@ -196,14 +217,14 @@ class TestActionKeywordArgs(TestCase):
class TestActionAppliedToExistingRoute(TestCase): class TestActionAppliedToExistingRoute(TestCase):
""" """
Ensure `@action` decorator raises an except when applied Ensure `@detail_route` decorator raises an except when applied
to an existing route to an existing route
""" """
def test_exception_raised_when_action_applied_to_existing_route(self): def test_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
@action() @detail_route(methods=['post'])
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
return Response({ return Response({
'hello': 'world' 'hello': 'world'
...@@ -214,3 +235,49 @@ class TestActionAppliedToExistingRoute(TestCase): ...@@ -214,3 +235,49 @@ class TestActionAppliedToExistingRoute(TestCase):
with self.assertRaises(ImproperlyConfigured): with self.assertRaises(ImproperlyConfigured):
self.router.urls self.router.urls
class DynamicListAndDetailViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
return Response({'method': 'list'})
@list_route(methods=['post'])
def list_route_post(self, request, *args, **kwargs):
return Response({'method': 'action1'})
@detail_route(methods=['post'])
def detail_route_post(self, request, *args, **kwargs):
return Response({'method': 'action2'})
@list_route()
def list_route_get(self, request, *args, **kwargs):
return Response({'method': 'link1'})
@detail_route()
def detail_route_get(self, request, *args, **kwargs):
return Response({'method': 'link2'})
class TestDynamicListAndDetailRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
def test_list_and_detail_route_decorators(self):
routes = self.router.get_routes(DynamicListAndDetailViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
# Make sure all these endpoints exist and none have been clobbered
for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']):
route = decorator_routes[i]
# check url listing
if endpoint.startswith('list_'):
self.assertEqual(route.url,
'^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint))
else:
self.assertEqual(route.url,
'^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
# check method to function mapping
if endpoint.endswith('_post'):
method_map = 'post'
else:
method_map = 'get'
self.assertEqual(route.mapping[method_map], endpoint)
...@@ -7,11 +7,11 @@ from django.utils import unittest ...@@ -7,11 +7,11 @@ from django.utils import unittest
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations from rest_framework import serializers, fields, relations
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel,
ForeignKeySource, ManyToManySource) ForeignKeySource, ManyToManySource)
from rest_framework.tests.models import BasicModelSerializer from tests.models import BasicModelSerializer
import datetime import datetime
import pickle import pickle
try: try:
...@@ -30,6 +30,7 @@ if PIL is not None: ...@@ -30,6 +30,7 @@ if PIL is not None:
image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True) slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True) url_field = models.URLField(max_length=1024, blank=True)
nullable_char_field = models.CharField(max_length=1024, blank=True, null=True)
class DVOAFModel(RESTFrameworkModel): class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True) positive_integer_field = models.PositiveIntegerField(blank=True)
...@@ -660,7 +661,7 @@ class ModelValidationTests(TestCase): ...@@ -660,7 +661,7 @@ class ModelValidationTests(TestCase):
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],}) self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True)
self.assertFalse(third_serializer.is_valid()) self.assertFalse(third_serializer.is_valid())
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
...@@ -1152,7 +1153,7 @@ class RelatedTraversalTest(TestCase): ...@@ -1152,7 +1153,7 @@ class RelatedTraversalTest(TestCase):
""" """
If a component of the dotted.source is None, return None for the field. If a component of the dotted.source is None, return None for the field.
""" """
from rest_framework.tests.models import NullableForeignKeySource from tests.models import NullableForeignKeySource
instance = NullableForeignKeySource.objects.create(name='Source with null FK') instance = NullableForeignKeySource.objects.create(name='Source with null FK')
class NullableSourceSerializer(serializers.Serializer): class NullableSourceSerializer(serializers.Serializer):
...@@ -1257,6 +1258,20 @@ class BlankFieldTests(TestCase): ...@@ -1257,6 +1258,20 @@ class BlankFieldTests(TestCase):
serializer = self.model_serializer_class(data={}) serializer = self.model_serializer_class(data={})
self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.is_valid(), True)
def test_create_model_null_field_save(self):
"""
Regression test for #1330.
https://github.com/tomchristie/django-rest-framework/pull/1330
"""
serializer = self.model_serializer_class(data={'title': None})
self.assertEqual(serializer.is_valid(), True)
try:
serializer.save()
except Exception:
self.fail('Exception raised on save() after validation passes')
#test for issue #460 #test for issue #460
class SerializerPickleTests(TestCase): class SerializerPickleTests(TestCase):
...@@ -1491,7 +1506,7 @@ class NestedSerializerContextTests(TestCase): ...@@ -1491,7 +1506,7 @@ class NestedSerializerContextTests(TestCase):
model = Album model = Album
fields = ("photo_set", "callable") fields = ("photo_set", "callable")
photo_set = PhotoSerializer(source="photo_set") photo_set = PhotoSerializer(source="photo_set", many=True)
callable = serializers.SerializerMethodField("_callable") callable = serializers.SerializerMethodField("_callable")
def _callable(self, instance): def _callable(self, instance):
...@@ -1503,7 +1518,7 @@ class NestedSerializerContextTests(TestCase): ...@@ -1503,7 +1518,7 @@ class NestedSerializerContextTests(TestCase):
albums = None albums = None
class AlbumCollectionSerializer(serializers.Serializer): class AlbumCollectionSerializer(serializers.Serializer):
albums = AlbumSerializer(source="albums") albums = AlbumSerializer(source="albums", many=True)
album1 = Album.objects.create(title="album 1") album1 = Album.objects.create(title="album 1")
album2 = Album.objects.create(title="album 2") album2 = Album.objects.create(title="album 2")
...@@ -1660,6 +1675,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase): ...@@ -1660,6 +1675,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
'url_field': [ 'url_field': [
('max_length', 1024), ('max_length', 1024),
], ],
'nullable_char_field': [
('max_length', 1024),
('allow_none', True),
],
} }
def field_test(self, field): def field_test(self, field):
...@@ -1696,6 +1715,9 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase): ...@@ -1696,6 +1715,9 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def test_url_field(self): def test_url_field(self):
self.field_test('url_field') self.field_test('url_field')
def test_nullable_char_field(self):
self.field_test('nullable_char_field')
@unittest.skipUnless(PIL is not None, 'PIL is not installed') @unittest.skipUnless(PIL is not None, 'PIL is not installed')
class DefaultValuesOnAutogeneratedFieldsTests(TestCase): class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
......
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.accounts.serializers import AccountSerializer from tests.accounts.serializers import AccountSerializer
class ImportingModelSerializerTests(TestCase): class ImportingModelSerializerTests(TestCase):
......
...@@ -2,7 +2,8 @@ from django.db import models ...@@ -2,7 +2,8 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework.serializers import _resolve_model from rest_framework.serializers import _resolve_model
from rest_framework.tests.models import BasicModel from tests.models import BasicModel
from rest_framework.compat import six
class ResolveModelTests(TestCase): class ResolveModelTests(TestCase):
...@@ -19,6 +20,10 @@ class ResolveModelTests(TestCase): ...@@ -19,6 +20,10 @@ class ResolveModelTests(TestCase):
resolved_model = _resolve_model('tests.BasicModel') resolved_model = _resolve_model('tests.BasicModel')
self.assertEqual(resolved_model, BasicModel) self.assertEqual(resolved_model, BasicModel)
def test_resolve_unicode_representation(self):
resolved_model = _resolve_model(six.text_type('tests.BasicModel'))
self.assertEqual(resolved_model, BasicModel)
def test_resolve_non_django_model(self): def test_resolve_non_django_model(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_resolve_model(TestCase) _resolve_model(TestCase)
......
...@@ -10,13 +10,13 @@ class TestSettings(TestCase): ...@@ -10,13 +10,13 @@ class TestSettings(TestCase):
def test_non_import_errors(self): def test_non_import_errors(self):
"""Make sure other errors aren't suppressed.""" """Make sure other errors aren't suppressed."""
settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
settings.DEFAULT_MODEL_SERIALIZER_CLASS settings.DEFAULT_MODEL_SERIALIZER_CLASS
def test_import_error_message_maintained(self): def test_import_error_message_maintained(self):
"""Make sure real import errors are captured and raised sensibly.""" """Make sure real import errors are captured and raised sensibly."""
settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
with self.assertRaises(ImportError) as cm: with self.assertRaises(ImportError) as cm:
settings.DEFAULT_MODEL_SERIALIZER_CLASS settings.DEFAULT_MODEL_SERIALIZER_CLASS
self.assertTrue('ImportError' in str(cm.exception)) self.assertTrue('ImportError' in str(cm.exception))
# -- coding: utf-8 -- # -- coding: utf-8 --
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url
from io import BytesIO from io import BytesIO
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory, force_authenticate from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
...@@ -35,7 +35,7 @@ urlpatterns = patterns('', ...@@ -35,7 +35,7 @@ urlpatterns = patterns('',
class TestAPITestClient(TestCase): class TestAPITestClient(TestCase):
urls = 'rest_framework.tests.test_testing' urls = 'tests.test_testing'
def setUp(self): def setUp(self):
self.client = APIClient() self.client = APIClient()
......
...@@ -5,6 +5,7 @@ from __future__ import unicode_literals ...@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
...@@ -275,3 +276,68 @@ class ScopedRateThrottleTests(TestCase): ...@@ -275,3 +276,68 @@ class ScopedRateThrottleTests(TestCase):
self.increment_timer() self.increment_timer()
response = self.unscoped_view(request) response = self.unscoped_view(request)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
class XffTestingBase(TestCase):
def setUp(self):
class Throttle(ScopedRateThrottle):
THROTTLE_RATES = {'test_limit': '1/day'}
TIMER_SECONDS = 0
timer = lambda self: self.TIMER_SECONDS
class View(APIView):
throttle_classes = (Throttle,)
throttle_scope = 'test_limit'
def get(self, request):
return Response('test_limit')
cache.clear()
self.throttle = Throttle()
self.view = View.as_view()
self.request = APIRequestFactory().get('/some_uri')
self.request.META['REMOTE_ADDR'] = '3.3.3.3'
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
def config_proxy(self, num_proxies):
setattr(api_settings, 'NUM_PROXIES', num_proxies)
class IdWithXffBasicTests(XffTestingBase):
def test_accepts_request_under_limit(self):
self.config_proxy(0)
self.assertEqual(200, self.view(self.request).status_code)
def test_denies_request_over_limit(self):
self.config_proxy(0)
self.view(self.request)
self.assertEqual(429, self.view(self.request).status_code)
class XffSpoofingTests(XffTestingBase):
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
self.config_proxy(1)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code)
def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
self.config_proxy(2)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code)
class XffUniqueMachinesTest(XffTestingBase):
def test_unique_clients_are_counted_independently_with_one_proxy(self):
self.config_proxy(1)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
self.assertEqual(200, self.view(self.request).status_code)
def test_unique_clients_are_counted_independently_with_two_proxies(self):
self.config_proxy(2)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
self.assertEqual(200, self.view(self.request).status_code)
from __future__ import unicode_literals from __future__ import unicode_literals
from collections import namedtuple from collections import namedtuple
from django.conf.urls import patterns, url, include
from django.core import urlresolvers from django.core import urlresolvers
from django.test import TestCase from django.test import TestCase
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.compat import patterns, url, include
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
......
from __future__ import unicode_literals from __future__ import unicode_literals
import sys
import copy import copy
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
...@@ -11,6 +12,11 @@ from rest_framework.views import APIView ...@@ -11,6 +12,11 @@ from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
if sys.version_info[:2] >= (3, 4):
JSON_ERROR = 'JSON parse error - Expecting value:'
else:
JSON_ERROR = 'JSON parse error - No JSON object could be decoded'
class BasicView(APIView): class BasicView(APIView):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
...@@ -48,7 +54,7 @@ def sanitise_json_error(error_dict): ...@@ -48,7 +54,7 @@ def sanitise_json_error(error_dict):
of json. of json.
""" """
ret = copy.copy(error_dict) ret = copy.copy(error_dict)
chop = len('JSON parse error - No JSON object could be decoded') chop = len(JSON_ERROR)
ret['detail'] = ret['detail'][:chop] ret['detail'] = ret['detail'][:chop]
return ret return ret
...@@ -61,7 +67,7 @@ class ClassBasedViewIntegrationTests(TestCase): ...@@ -61,7 +67,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': 'JSON parse error - No JSON object could be decoded' 'detail': JSON_ERROR
} }
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected) self.assertEqual(sanitise_json_error(response.data), expected)
...@@ -76,7 +82,7 @@ class ClassBasedViewIntegrationTests(TestCase): ...@@ -76,7 +82,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': 'JSON parse error - No JSON object could be decoded' 'detail': JSON_ERROR
} }
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected) self.assertEqual(sanitise_json_error(response.data), expected)
...@@ -90,7 +96,7 @@ class FunctionBasedViewIntegrationTests(TestCase): ...@@ -90,7 +96,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': 'JSON parse error - No JSON object could be decoded' 'detail': JSON_ERROR
} }
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected) self.assertEqual(sanitise_json_error(response.data), expected)
...@@ -105,7 +111,7 @@ class FunctionBasedViewIntegrationTests(TestCase): ...@@ -105,7 +111,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': 'JSON parse error - No JSON object could be decoded' 'detail': JSON_ERROR
} }
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected) self.assertEqual(sanitise_json_error(response.data), expected)
......
This diff is collapsed. Click to expand it.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment