Commit 1aa77830 by Eleni Lixourioti

Merge branch 'version-3.1' of github.com:tomchristie/django-rest-framework into oauth_as_package

Conflicts:
	.travis.yml
parents afaa52a3 88008c0a
language: python language: python
python: python: 2.7
- "2.6"
- "2.7"
- "3.2"
- "3.3"
- "3.4"
env: env:
- DJANGO="django==1.7" - TOX_ENV=flake8
- DJANGO="django==1.6.5" - TOX_ENV=py3.4-django1.7
- DJANGO="django==1.5.8" - TOX_ENV=py3.3-django1.7
- DJANGO="django==1.4.13" - TOX_ENV=py3.2-django1.7
- TOX_ENV=py2.7-django1.7
- TOX_ENV=py3.4-django1.6
- TOX_ENV=py3.3-django1.6
- TOX_ENV=py3.2-django1.6
- TOX_ENV=py2.7-django1.6
- TOX_ENV=py2.6-django1.6
- TOX_ENV=py3.4-django1.5
- TOX_ENV=py3.3-django1.5
- TOX_ENV=py3.2-django1.5
- TOX_ENV=py2.7-django1.5
- TOX_ENV=py2.6-django1.5
- TOX_ENV=py2.7-django1.4
- TOX_ENV=py2.6-django1.4
install: install:
- pip install $DJANGO - "pip install tox --download-cache $HOME/.pip-cache"
- pip install defusedxml==0.3
- pip install Pillow==2.3.0
- pip install django-guardian==1.2.3
- pip install pytest-django==2.6.1
- pip install flake8==2.2.2
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${DJANGO} == 'django==1.7' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=.
script: script:
- ./runtests.py - tox -e $TOX_ENV
matrix:
exclude:
- python: "2.6"
env: DJANGO="django==1.7"
- python: "3.2"
env: DJANGO="django==1.4.13"
- python: "3.3"
env: DJANGO="django==1.4.13"
- python: "3.4"
env: DJANGO="django==1.4.13"
...@@ -84,7 +84,7 @@ Note that the exception handler will only be called for responses generated by r ...@@ -84,7 +84,7 @@ Note that the exception handler will only be called for responses generated by r
**Signature:** `APIException()` **Signature:** `APIException()`
The **base class** for all exceptions raised inside REST framework. The **base class** for all exceptions raised inside an `APIView` class or `@api_view`.
To provide a custom exception, subclass `APIException` and set the `.status_code` and `.default_detail` properties on the class. To provide a custom exception, subclass `APIException` and set the `.status_code` and `.default_detail` properties on the class.
......
...@@ -274,7 +274,27 @@ Corresponds to `django.db.models.fields.FloatField`. ...@@ -274,7 +274,27 @@ Corresponds to `django.db.models.fields.FloatField`.
## DecimalField ## DecimalField
A decimal representation. A decimal representation, represented in Python by a Decimal instance.
Has two required arguments:
- `max_digits` The maximum number of digits allowed in the number. Note that this number must be greater than or equal to decimal_places.
- `decimal_places` The number of decimal places to store with the number.
For example, to validate numbers up to 999 with a resolution of 2 decimal places, you would use:
serializers.DecimalField(max_digits=5, decimal_places=2)
And to validate numbers up to anything lesss than one billion with a resolution of 10 decimal places:
serializers.DecimalField(max_digits=19, decimal_places=10)
This field also takes an optional argument, `coerce_to_string`. If set to `True` the representation will be output as a string. If set to `False` the representation will be left as a `Decimal` instance and the final representation will be determined by the renderer.
If unset, this will default to the same value as the `COERCE_DECIMAL_TO_STRING` setting, which is `True` unless set otherwise.
**Signature:** `DecimalField(max_digits, decimal_places, coerce_to_string=None)`
Corresponds to `django.db.models.fields.DecimalField`. Corresponds to `django.db.models.fields.DecimalField`.
......
...@@ -193,7 +193,7 @@ filters using `Manufacturer` name. For example: ...@@ -193,7 +193,7 @@ filters using `Manufacturer` name. For example:
class ProductFilter(django_filters.FilterSet): class ProductFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Product model = Product
fields = ['category', 'in_stock', 'manufacturer__name`] fields = ['category', 'in_stock', 'manufacturer__name']
This enables us to make queries like: This enables us to make queries like:
...@@ -211,7 +211,7 @@ This is nice, but it exposes the Django's double underscore convention as part o ...@@ -211,7 +211,7 @@ This is nice, but it exposes the Django's double underscore convention as part o
class Meta: class Meta:
model = Product model = Product
fields = ['category', 'in_stock', 'manufacturer`] fields = ['category', 'in_stock', 'manufacturer']
And now you can execute: And now you can execute:
......
...@@ -212,8 +212,6 @@ Provides a `.list(request, *args, **kwargs)` method, that implements listing a q ...@@ -212,8 +212,6 @@ Provides a `.list(request, *args, **kwargs)` method, that implements listing a q
If the queryset is populated, this returns a `200 OK` response, with a serialized representation of the queryset as the body of the response. The response data may optionally be paginated. If the queryset is populated, this returns a `200 OK` response, with a serialized representation of the queryset as the body of the response. The response data may optionally be paginated.
If the queryset is empty this returns a `200 OK` response, unless the `.allow_empty` attribute on the view is set to `False`, in which case it will return a `404 Not Found`.
## CreateModelMixin ## CreateModelMixin
Provides a `.create(request, *args, **kwargs)` method, that implements creating and saving a new model instance. Provides a `.create(request, *args, **kwargs)` method, that implements creating and saving a new model instance.
......
...@@ -74,37 +74,18 @@ If your API includes views that can serve both regular webpages and API response ...@@ -74,37 +74,18 @@ If your API includes views that can serve both regular webpages and API response
Renders the request data into `JSON`, using utf-8 encoding. Renders the request data into `JSON`, using utf-8 encoding.
Note that non-ascii characters will be rendered using JSON's `\uXXXX` character escape. For example: Note that the default style is to include unicode characters, and render the response using a compact style with no uneccessary whitespace:
{"unicode black star": "\u2605"} {"unicode black star":"★","value":999}
The client may additionally include an `'indent'` media type parameter, in which case the returned `JSON` will be indented. For example `Accept: application/json; indent=4`. The client may additionally include an `'indent'` media type parameter, in which case the returned `JSON` will be indented. For example `Accept: application/json; indent=4`.
{ {
"unicode black star": "\u2605" "unicode black star": "★",
"value": 999
} }
**.media_type**: `application/json` The default JSON encoding style can be altered using the `UNICODE_JSON` and `COMPACT_JSON` settings keys.
**.format**: `'.json'`
**.charset**: `None`
## UnicodeJSONRenderer
Renders the request data into `JSON`, using utf-8 encoding.
Note that non-ascii characters will not be character escaped. For example:
{"unicode black star": "★"}
The client may additionally include an `'indent'` media type parameter, in which case the returned `JSON` will be indented. For example `Accept: application/json; indent=4`.
{
"unicode black star": "★"
}
Both the `JSONRenderer` and `UnicodeJSONRenderer` styles conform to [RFC 4627][rfc4627], and are syntactically valid JSON.
**.media_type**: `application/json` **.media_type**: `application/json`
...@@ -444,6 +425,11 @@ Comma-separated values are a plain-text tabular data format, that can be easily ...@@ -444,6 +425,11 @@ Comma-separated values are a plain-text tabular data format, that can be easily
[djangorestframework-camel-case] provides camel case JSON renderers and parsers for REST framework. This allows serializers to use Python-style underscored field names, but be exposed in the API as Javascript-style camel case field names. It is maintained by [Vitaly Babiy][vbabiy]. [djangorestframework-camel-case] provides camel case JSON renderers and parsers for REST framework. This allows serializers to use Python-style underscored field names, but be exposed in the API as Javascript-style camel case field names. It is maintained by [Vitaly Babiy][vbabiy].
## Pandas (CSV, Excel, PNG)
[Django REST Pandas] provides a serializer and renderers that support additional data processing and output via the [Pandas] DataFrame API. Django REST Pandas includes renderers for Pandas-style CSV files, Excel workbooks (both `.xls` and `.xlsx`), and a number of [other formats]. It is maintained by [S. Andrew Sheppard][sheppard] as part of the [wq Project][wq].
[cite]: https://docs.djangoproject.com/en/dev/ref/template-response/#the-rendering-process [cite]: https://docs.djangoproject.com/en/dev/ref/template-response/#the-rendering-process
[conneg]: content-negotiation.md [conneg]: content-negotiation.md
[browser-accept-headers]: http://www.gethifi.com/blog/browser-rest-http-accept-headers [browser-accept-headers]: http://www.gethifi.com/blog/browser-rest-http-accept-headers
...@@ -467,3 +453,8 @@ Comma-separated values are a plain-text tabular data format, that can be easily ...@@ -467,3 +453,8 @@ Comma-separated values are a plain-text tabular data format, that can be easily
[hzy]: https://github.com/hzy [hzy]: https://github.com/hzy
[drf-ujson-renderer]: https://github.com/gizmag/drf-ujson-renderer [drf-ujson-renderer]: https://github.com/gizmag/drf-ujson-renderer
[djangorestframework-camel-case]: https://github.com/vbabiy/djangorestframework-camel-case [djangorestframework-camel-case]: https://github.com/vbabiy/djangorestframework-camel-case
[Django REST Pandas]: https://github.com/wq/django-rest-pandas
[Pandas]: http://pandas.pydata.org/
[other formats]: https://github.com/wq/django-rest-pandas#supported-formats
[sheppard]: https://github.com/sheppard
[wq]: https://github.com/wq
...@@ -265,7 +265,7 @@ A format string that should be used by default for rendering the output of `Date ...@@ -265,7 +265,7 @@ A format string that should be used by default for rendering the output of `Date
May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string. May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string.
Default: `None` Default: `'iso-8601'`
#### DATETIME_INPUT_FORMATS #### DATETIME_INPUT_FORMATS
...@@ -281,7 +281,7 @@ A format string that should be used by default for rendering the output of `Date ...@@ -281,7 +281,7 @@ A format string that should be used by default for rendering the output of `Date
May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string. May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string.
Default: `None` Default: `'iso-8601'`
#### DATE_INPUT_FORMATS #### DATE_INPUT_FORMATS
...@@ -297,7 +297,7 @@ A format string that should be used by default for rendering the output of `Time ...@@ -297,7 +297,7 @@ A format string that should be used by default for rendering the output of `Time
May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string. May be any of `None`, `'iso-8601'` or a Python [strftime format][strftime] string.
Default: `None` Default: `'iso-8601'`
#### TIME_INPUT_FORMATS #### TIME_INPUT_FORMATS
...@@ -309,6 +309,46 @@ Default: `['iso-8601']` ...@@ -309,6 +309,46 @@ Default: `['iso-8601']`
--- ---
## Encodings
#### UNICODE_JSON
When set to `True`, JSON responses will allow unicode characters in responses. For example:
{"unicode black star":"★"}
When set to `False`, JSON responses will escape non-ascii characters, like so:
{"unicode black star":"\u2605"}
Both styles conform to [RFC 4627][rfc4627], and are syntactically valid JSON. The unicode style is prefered as being more user-friendly when inspecting API responses.
Default: `True`
#### COMPACT_JSON
When set to `True`, JSON responses will return compact representations, with no spacing after `':'` and `','` characters. For example:
{"is_admin":false,"email":"jane@example"}
When set to `False`, JSON responses will return slightly more verbose representations, like so:
{"is_admin": false, "email": "jane@example"}
The default style is to return minified responses, in line with [Heroku's API design guidelines][heroku-minified-json].
Default: `True`
#### COERCE_DECIMAL_TO_STRING
When returning decimal objects in API representations that do not support a native decimal type, it is normally best to return the value as a string. This avoids the loss of precision that occurs with binary floating point implementations.
When set to `True`, the serializer `DecimalField` class will return strings instead of `Decimal` objects. When set to `False`, serializers will return `Decimal` objects, which the default JSON encoder will return as floats.
Default: `True`
---
## View names and descriptions ## View names and descriptions
**The following settings are used to generate the view names and descriptions, as used in responses to `OPTIONS` requests, and as used in the browsable API.** **The following settings are used to generate the view names and descriptions, as used in responses to `OPTIONS` requests, and as used in the browsable API.**
...@@ -378,4 +418,6 @@ An integer of 0 or more, that may be used to specify the number of application p ...@@ -378,4 +418,6 @@ An integer of 0 or more, that may be used to specify the number of application p
Default: `None` Default: `None`
[cite]: http://www.python.org/dev/peps/pep-0020/ [cite]: http://www.python.org/dev/peps/pep-0020/
[rfc4627]: http://www.ietf.org/rfc/rfc4627.txt
[heroku-minified-json]: https://github.com/interagent/http-api-design#keep-json-minified-in-all-responses
[strftime]: http://docs.python.org/2/library/time.html#time.strftime [strftime]: http://docs.python.org/2/library/time.html#time.strftime
...@@ -178,6 +178,8 @@ To create a custom throttle, override `BaseThrottle` and implement `.allow_reque ...@@ -178,6 +178,8 @@ To create a custom throttle, override `BaseThrottle` and implement `.allow_reque
Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`. Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`.
If the `.wait()` method is implemented and the request is throttled, then a `Retry-After` header will be included in the response.
## Example ## Example
The following is an example of a rate throttle, that will randomly throttle 1 in every 10 requests. The following is an example of a rate throttle, that will randomly throttle 1 in every 10 requests.
......
...@@ -192,6 +192,7 @@ General guides to using REST framework. ...@@ -192,6 +192,7 @@ General guides to using REST framework.
* [Browser enhancements][browser-enhancements] * [Browser enhancements][browser-enhancements]
* [The Browsable API][browsableapi] * [The Browsable API][browsableapi]
* [REST, Hypermedia & HATEOAS][rest-hypermedia-hateoas] * [REST, Hypermedia & HATEOAS][rest-hypermedia-hateoas]
* [Third Party Resources][third-party-resources]
* [Contributing to REST framework][contributing] * [Contributing to REST framework][contributing]
* [2.0 Announcement][rest-framework-2-announcement] * [2.0 Announcement][rest-framework-2-announcement]
* [2.2 Announcement][2.2-announcement] * [2.2 Announcement][2.2-announcement]
...@@ -307,6 +308,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -307,6 +308,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
[browsableapi]: topics/browsable-api.md [browsableapi]: topics/browsable-api.md
[rest-hypermedia-hateoas]: topics/rest-hypermedia-hateoas.md [rest-hypermedia-hateoas]: topics/rest-hypermedia-hateoas.md
[contributing]: topics/contributing.md [contributing]: topics/contributing.md
[third-party-resources]: topics/third-party-resources.md
[rest-framework-2-announcement]: topics/rest-framework-2-announcement.md [rest-framework-2-announcement]: topics/rest-framework-2-announcement.md
[2.2-announcement]: topics/2.2-announcement.md [2.2-announcement]: topics/2.2-announcement.md
[2.3-announcement]: topics/2.3-announcement.md [2.3-announcement]: topics/2.3-announcement.md
......
...@@ -117,6 +117,7 @@ a.fusion-poweredby { ...@@ -117,6 +117,7 @@ a.fusion-poweredby {
<li><a href="{{ base_url }}/topics/browser-enhancements{{ suffix }}">Browser enhancements</a></li> <li><a href="{{ base_url }}/topics/browser-enhancements{{ suffix }}">Browser enhancements</a></li>
<li><a href="{{ base_url }}/topics/browsable-api{{ suffix }}">The Browsable API</a></li> <li><a href="{{ base_url }}/topics/browsable-api{{ suffix }}">The Browsable API</a></li>
<li><a href="{{ base_url }}/topics/rest-hypermedia-hateoas{{ suffix }}">REST, Hypermedia & HATEOAS</a></li> <li><a href="{{ base_url }}/topics/rest-hypermedia-hateoas{{ suffix }}">REST, Hypermedia & HATEOAS</a></li>
<li><a href="{{ base_url }}/topics/third-party-resources{{ suffix }}">Third Party Resources</a></li>
<li><a href="{{ base_url }}/topics/contributing{{ suffix }}">Contributing to REST framework</a></li> <li><a href="{{ base_url }}/topics/contributing{{ suffix }}">Contributing to REST framework</a></li>
<li><a href="{{ base_url }}/topics/rest-framework-2-announcement{{ suffix }}">2.0 Announcement</a></li> <li><a href="{{ base_url }}/topics/rest-framework-2-announcement{{ suffix }}">2.0 Announcement</a></li>
<li><a href="{{ base_url }}/topics/2.2-announcement{{ suffix }}">2.2 Announcement</a></li> <li><a href="{{ base_url }}/topics/2.2-announcement{{ suffix }}">2.2 Announcement</a></li>
......
**THIS DOCUMENT IS CURRENTLY A WORK IN PROGRESS**
See the [Version 3.0 GitHub issue](https://github.com/tomchristie/django-rest-framework/pull/1800) for more details.
# REST framework 3.0
**Note incremental nature, discuss upgrading.**
## Motivation
**TODO**
---
## Request objects
#### The `request.data` property.
**TODO**
#### The parser API.
**TODO**
## Serializers
#### Single-step object creation.
**TODO**: Drop `.restore_object()`, use `.create()` and `.update()` which should save the instance.
**TODO**: Drop`.object`, use `.validated_data` or get the instance with `.save()`.
#### Always use `fields`, not `exclude`.
The `exclude` option is no longer available. You should use the more explicit `fields` option instead.
#### The `extra_kwargs` option.
The `read_only_fields` and `write_only_fields` options have been removed and replaced with a more generic `extra_kwargs`.
class MySerializer(serializer.ModelSerializer):
class Meta:
model = MyModel
fields = ('id', 'email', 'notes', 'is_admin')
extra_kwargs = {
'is_admin': {'read_only': True}
}
Alternatively, specify the field explicitly on the serializer class:
class MySerializer(serializer.ModelSerializer):
is_admin = serializers.BooleanField(read_only=True)
class Meta:
model = MyModel
fields = ('id', 'email', 'notes', 'is_admin')
#### Changes to `HyperlinkedModelSerializer`.
The `view_name` and `lookup_field` options have been removed. They are no longer required, as you can use the `extra_kwargs` argument instead:
class MySerializer(serializer.HyperlinkedModelSerializer):
class Meta:
model = MyModel
fields = ('url', 'email', 'notes', 'is_admin')
extra_kwargs = {
'url': {'lookup_field': 'uuid'}
}
Alternatively, specify the field explicitly on the serializer class:
class MySerializer(serializer.HyperlinkedModelSerializer):
url = serializers.HyperlinkedIdentityField(
view_name='mymodel-detail',
lookup_field='uuid'
)
class Meta:
model = MyModel
fields = ('url', 'email', 'notes', 'is_admin')
#### Fields for model methods and properties.
You can now specify field names in the `fields` option that refer to model methods or properties. For example, suppose you have the following model:
class Invitation(models.Model):
created = models.DateTimeField()
to_email = models.EmailField()
message = models.CharField(max_length=1000)
def expiry_date(self):
return self.created + datetime.timedelta(days=30)
You can include `expiry_date` as a field option on a `ModelSerializer` class.
class InvitationSerializer(serializers.ModelSerializer):
class Meta:
model = Invitation
fields = ('to_email', 'message', 'expiry_date')
These fields will be mapped to `serializers.ReadOnlyField()` instances.
>>> serializer = InvitationSerializer()
>>> print repr(serializer)
InvitationSerializer():
to_email = EmailField(max_length=75)
message = CharField(max_length=1000)
expiry_date = ReadOnlyField()
## Serializer fields
#### The `Field` and `ReadOnly` field classes.
**TODO**
#### Coercing output types.
**TODO**
#### The `ListSerializer` class.
**TODO**
#### The `MultipleChoiceField` class.
**TODO**
#### Changes to the custom field API.
**TODO** `to_representation`, `to_internal_value`.
#### Explicit `querysets` required on relational fields.
**TODO**
#### Optional argument to `SerializerMethodField`.
**TODO**
## Generic views
#### Simplification of view logic.
**TODO**
#### Removal of pre/post save hooks.
The following method hooks no longer exist on the new, simplified, generic views: `pre_save`, `post_save`, `pre_delete`, `post_delete`.
If you do need custom behavior, you might choose to instead override the `.save()` method on your serializer class. For example:
def save(self, *args, **kwargs):
instance = super(MySerializer).save(*args, **kwarg)
send_email(instance.to_email, instance.message)
return instance
Alternatively write your view logic exlpicitly, or tie your pre/post save behavior into the model class or model manager.
#### Removal of view attributes.
The `.object` and `.object_list` attributes are no longer set on the view instance. Treating views as mutable object instances that store state during the processing of the view tends to be poor design, and can lead to obscure flow logic.
I would personally recommend that developers treat view instances as immutable objects in their application code.
#### PUT as create.
**TODO**
## API style
There are some improvements in the default style we use in our API responses.
#### Unicode JSON by default.
Unicode JSON is now the default. The `UnicodeJSONRenderer` class no longer exists, and the `UNICODE_JSON` setting has been added. To revert this behavior use the new setting:
REST_FRAMEWORK = {
'UNICODE_JSON': False
}
#### Compact JSON by default.
We now output compact JSON in responses by default. For example, we return:
{"email":"amy@example.com","is_admin":true}
Instead of the following:
{"email": "amy@example.com", "is_admin": true}
The `COMPACT_JSON` setting has been added, and can be used to revert this behavior if needed:
REST_FRAMEWORK = {
'COMPACT_JSON': False
}
#### Throttle headers using `Retry-After`.
The custom `X-Throttle-Wait-Second` header has now been dropped in favor of the standard `Retry-After` header. You can revert this behavior if needed by writing a custom exception handler for your application.
#### Date and time objects as ISO-8859-1 strings in serializer data.
Date and Time objects are now coerced to strings by default in the serializer output. Previously they were returned as `Date`, `Time` and `DateTime` objects, and later coerced to strings by the renderer.
You can modify this behavior globally by settings the existing `DATE_FORMAT`, `DATETIME_FORMAT` and `TIME_FORMAT` settings keys. Setting these values to `None` instead of their default value of `'iso-8859-1'` will result in native objects being returned in serializer data.
REST_FRAMEWORK = {
# Return native `Date` and `Time` objects in `serializer.data`
'DATETIME_FORMAT': None
'DATE_FORMAT': None
'TIME_FORMAT': None
}
You can also modify serializer fields individually, using the `date_format`, `time_format` and `datetime_format` arguments:
# Return `DateTime` instances in `serializer.data`, not strings.
created = serializers.DateTimeField(format=None)
#### Decimals as strings in serializer data.
Decimals are now coerced to strings by default in the serializer output. Previously they were returned as `Decimal` objects, and later coerced to strings by the renderer.
You can modify this behavior globally by using the `COERCE_DECIMAL_TO_STRING` settings key.
REST_FRAMEWORK = {
'COERCE_DECIMAL_TO_STRING': False
}
Or modify it on an individual serializer field, using the `corece_to_string` keyword argument.
# Return `Decimal` instances in `serializer.data`, not strings.
amount = serializers.DecimalField(
max_digits=10,
decimal_places=2,
coerce_to_string=False
)
The default JSON renderer will return float objects for uncoerced `Decimal` instances. This allows you to easily switch between string or float representations for decimals depending on your API design needs.
# Third Party Resources
Django REST Framework has a growing community of developers, packages, and resources.
Check out a grid detailing all the packages and ecosystem around Django REST Framework at [Django Packages](https://www.djangopackages.com/grids/g/django-rest-framework/).
To submit new content, [open an issue](https://github.com/tomchristie/django-rest-framework/issues/new) or [create a pull request](https://github.com/tomchristie/django-rest-framework/).
## Libraries and Extensions
### Authentication
* [djangorestframework-digestauth](https://github.com/juanriaza/django-rest-framework-digestauth) - Provides Digest Access Authentication support.
* [django-oauth-toolkit](https://github.com/evonove/django-oauth-toolkit) - Provides OAuth 2.0 support.
* [doac](https://github.com/Rediker-Software/doac) - Provides OAuth 2.0 support.
* [djangorestframework-jwt](https://github.com/GetBlimp/django-rest-framework-jwt) - Provides JSON Web Token Authentication support.
* [hawkrest](https://github.com/kumar303/hawkrest) - Provides Hawk HTTP Authorization.
* [djangorestframework-httpsignature](https://github.com/etoccalino/django-rest-framework-httpsignature) - Provides an easy to use HTTP Signature Authentication mechanism.
### Permissions
* [drf-any-permissions](https://github.com/kevin-brown/drf-any-permissions) - Provides alternative permission handling.
* [djangorestframework-composed-permissions](https://github.com/niwibe/djangorestframework-composed-permissions) - Provides a simple way to define complex permissions.
* [rest_condition](https://github.com/caxap/rest_condition) - Another extension for building complex permissions in a simple and convenient way.
### Serializers
* [django-rest-framework-mongoengine](https://github.com/umutbozkurt/django-rest-framework-mongoengine) - Serializer class that supports using MongoDB as the storage layer for Django REST framework.
* [djangorestframework-gis](https://github.com/djangonauts/django-rest-framework-gis) - Geographic add-ons
* [djangorestframework-hstore](https://github.com/djangonauts/django-rest-framework-hstore) - Serializer class to support django-hstore DictionaryField model field and its schema-mode feature.
### Serializer fields
* [drf-compound-fields](https://github.com/estebistec/drf-compound-fields) - Provides "compound" serializer fields, such as lists of simple values.
* [django-extra-fields](https://github.com/Hipo/drf-extra-fields) - Provides extra serializer fields.
### Views
* [djangorestframework-bulk](https://github.com/miki725/django-rest-framework-bulk) - Implements generic view mixins as well as some common concrete generic views to allow to apply bulk operations via API requests.
### Routers
* [drf-nested-routers](https://github.com/alanjds/drf-nested-routers) - Provides routers and relationship fields for working with nested resources.
* [wq.db.rest](http://wq.io/docs/about-rest) - Provides an admin-style model registration API with reasonable default URLs and viewsets.
### Parsers
* [djangorestframework-msgpack](https://github.com/juanriaza/django-rest-framework-msgpack) - Provides MessagePack renderer and parser support.
* [djangorestframework-camel-case](https://github.com/vbabiy/djangorestframework-camel-case) - Provides camel case JSON renderers and parsers.
### Renderers
* [djangorestframework-csv](https://github.com/mjumbewu/django-rest-framework-csv) - Provides CSV renderer support.
* [drf_ujson](https://github.com/gizmag/drf-ujson-renderer) - Implements JSON rendering using the UJSON package.
* [Django REST Pandas](https://github.com/wq/django-rest-pandas) - Pandas DataFrame-powered renderers including Excel, CSV, and SVG formats.
### Filtering
* [djangorestframework-chain](https://github.com/philipn/django-rest-framework-chain) - Allows arbitrary chaining of both relations and lookup filters.
### Misc
* [djangorestrelationalhyperlink](https://github.com/fredkingham/django_rest_model_hyperlink_serializers_project) - A hyperlinked serialiser that can can be used to alter relationships via hyperlinks, but otherwise like a hyperlink model serializer.
* [django-rest-swagger](https://github.com/marcgibbons/django-rest-swagger) - An API documentation generator for Swagger UI.
* [django-rest-framework-proxy ](https://github.com/eofs/django-rest-framework-proxy) - Proxy to redirect incoming request to another API server.
* [gaiarestframework](https://github.com/AppsFuel/gaiarestframework) - Utils for django-rest-framewok
* [drf-extensions](https://github.com/chibisov/drf-extensions) - A collection of custom extensions
* [ember-data-django-rest-adapter](https://github.com/toranb/ember-data-django-rest-adapter) - An ember-data adapter
## Tutorials
* [Beginner's Guide to the Django Rest Framework](http://code.tutsplus.com/tutorials/beginners-guide-to-the-django-rest-framework--cms-19786)
* [Getting Started with Django Rest Framework and AngularJS](http://blog.kevinastone.com/getting-started-with-django-rest-framework-and-angularjs.html)
* [End to end web app with Django-Rest-Framework & AngularJS](http://blog.mourafiq.com/post/55034504632/end-to-end-web-app-with-django-rest-framework)
* [Start Your API - django-rest-framework part 1](https://godjango.com/41-start-your-api-django-rest-framework-part-1/)
* [Permissions & Authentication - django-rest-framework part 2](https://godjango.com/43-permissions-authentication-django-rest-framework-part-2/)
* [ViewSets and Routers - django-rest-framework part 3](https://godjango.com/45-viewsets-and-routers-django-rest-framework-part-3/)
* [Django Rest Framework User Endpoint](http://richardtier.com/2014/02/25/django-rest-framework-user-endpoint/)
* [Check credentials using Django Rest Framework](http://richardtier.com/2014/03/06/110/)
## Videos
* [Ember and Django Part 1 (Video)](http://www.neckbeardrepublic.com/screencasts/ember-and-django-part-1)
* [Django Rest Framework Part 1 (Video)](http://www.neckbeardrepublic.com/screencasts/django-rest-framework-part-1)
* [Pyowa July 2013 - Django Rest Framework (Video)](http://www.youtube.com/watch?v=E1ZrehVxpBo)
* [django-rest-framework and angularjs (Video)](http://www.youtube.com/watch?v=Q8FRBGTJ020)
## Articles
* [Web API performance: profiling Django REST framework](http://dabapps.com/blog/api-performance-profiling-django-rest-framework/)
* [API Development with Django and Django REST Framework](https://bnotions.com/api-development-with-django-and-django-rest-framework/)
...@@ -76,6 +76,7 @@ path_list = [ ...@@ -76,6 +76,7 @@ path_list = [
'topics/browser-enhancements.md', 'topics/browser-enhancements.md',
'topics/browsable-api.md', 'topics/browsable-api.md',
'topics/rest-hypermedia-hateoas.md', 'topics/rest-hypermedia-hateoas.md',
'topics/third-party-resources.md',
'topics/contributing.md', 'topics/contributing.md',
'topics/rest-framework-2-announcement.md', 'topics/rest-framework-2-announcement.md',
'topics/2.2-announcement.md', 'topics/2.2-announcement.md',
......
...@@ -8,5 +8,6 @@ flake8==2.2.2 ...@@ -8,5 +8,6 @@ flake8==2.2.2
markdown>=2.1.0 markdown>=2.1.0
PyYAML>=3.10 PyYAML>=3.10
defusedxml>=0.3 defusedxml>=0.3
django-guardian==1.2.4
django-filter>=0.5.4 django-filter>=0.5.4
Pillow==2.3.0 Pillow==2.3.0
# encoding: utf8 # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models, migrations from django.db import models, migrations
...@@ -15,12 +15,11 @@ class Migration(migrations.Migration): ...@@ -15,12 +15,11 @@ class Migration(migrations.Migration):
migrations.CreateModel( migrations.CreateModel(
name='Token', name='Token',
fields=[ fields=[
('key', models.CharField(max_length=40, serialize=False, primary_key=True)), ('key', models.CharField(primary_key=True, serialize=False, max_length=40)),
('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, to_field='id')),
('created', models.DateTimeField(auto_now_add=True)), ('created', models.DateTimeField(auto_now_add=True)),
('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token')),
], ],
options={ options={
'abstract': False,
}, },
bases=(models.Model,), bases=(models.Model,),
), ),
......
...@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer): ...@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
else: else:
msg = _('Unable to login with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
else: else:
msg = _('Must include "username" and "password"') msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
...@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView): ...@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user']) user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key}) return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
...@@ -39,6 +39,17 @@ except ImportError: ...@@ -39,6 +39,17 @@ except ImportError:
django_filters = None django_filters = None
if django.VERSION >= (1, 6):
def clean_manytomany_helptext(text):
return text
else:
# Up to version 1.5 many to many fields automatically suffix
# the `help_text` attribute with hardcoded text.
def clean_manytomany_helptext(text):
if text.endswith(' Hold down "Control", or "Command" on a Mac, to select more than one.'):
text = text[:-69]
return text
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite. # Fixes (#1712). We keep the try/except for the test suite.
guardian = None guardian = None
......
...@@ -10,7 +10,6 @@ from __future__ import unicode_literals ...@@ -10,7 +10,6 @@ from __future__ import unicode_literals
from django.utils import six from django.utils import six
from rest_framework.views import APIView from rest_framework.views import APIView
import types import types
import warnings
def api_view(http_method_names): def api_view(http_method_names):
...@@ -130,37 +129,3 @@ def list_route(methods=['get'], **kwargs): ...@@ -130,37 +129,3 @@ def list_route(methods=['get'], **kwargs):
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator
# These are now pending deprecation, in favor of `detail_route` and `list_route`.
def link(**kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail GET requests.
"""
msg = 'link is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = ['get']
func.detail = True
func.kwargs = kwargs
return func
return decorator
def action(methods=['post'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail POST requests.
"""
msg = 'action is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = methods
func.detail = True
func.kwargs = kwargs
return func
return decorator
...@@ -15,7 +15,7 @@ class APIException(Exception): ...@@ -15,7 +15,7 @@ class APIException(Exception):
Subclasses should provide `.status_code` and `.default_detail` properties. Subclasses should provide `.status_code` and `.default_detail` properties.
""" """
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = '' default_detail = 'A server error occured'
def __init__(self, detail=None): def __init__(self, detail=None):
self.detail = detail or self.default_detail self.detail = detail or self.default_detail
...@@ -54,7 +54,7 @@ class MethodNotAllowed(APIException): ...@@ -54,7 +54,7 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException): class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = "Could not satisfy the request's Accept header" default_detail = "Could not satisfy the request Accept header"
def __init__(self, detail=None, available_renderers=None): def __init__(self, detail=None, available_renderers=None):
self.detail = detail or self.default_detail self.detail = detail or self.default_detail
......
""" from django.conf import settings
Serializer fields perform validation on incoming data.
They are very similar to Django's form fields.
"""
from __future__ import unicode_literals
import copy
import datetime
import inspect
import re
import warnings
from decimal import Decimal, DecimalException
from django import forms
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.conf import settings from django.utils import timezone
from django.db.models.fields import BLANK_CHOICE_DASH from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.http import QueryDict
from django.forms import widgets
from django.utils import six, timezone
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import smart_text
BytesIO, smart_text,
force_text, is_non_str_iterable
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect
import warnings
class empty:
"""
This class is used to represent no data being provided for a given input
or output value.
It is required because `None` may be a valid input or output value.
"""
pass
def is_simple_callable(obj): def is_simple_callable(obj):
...@@ -47,597 +41,487 @@ def is_simple_callable(obj): ...@@ -47,597 +41,487 @@ def is_simple_callable(obj):
return len_args <= len_defaults return len_args <= len_defaults
def get_component(obj, attr_name): def get_attribute(instance, attrs):
""" """
Given an object, and an attribute name, Similar to Python's built in `getattr(instance, attr)`,
return that attribute on the object. but takes a list of nested attributes, instead of a single attribute.
"""
if isinstance(obj, dict):
val = obj.get(attr_name)
else:
val = getattr(obj, attr_name)
if is_simple_callable(val): Also accepts either attribute lookup on objects or dictionary lookups.
return val() """
return val for attr in attrs:
try:
instance = getattr(instance, attr)
def readable_datetime_formats(formats): except AttributeError as exc:
format = ', '.join(formats).replace( try:
ISO_8601, return instance[attr]
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' except (KeyError, TypeError):
) raise exc
return humanize_strptime(format) return instance
def readable_date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
return humanize_strptime(format)
def readable_time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string
def strip_multiple_choice_msg(help_text): def set_value(dictionary, keys, value):
""" """
Remove the 'Hold down "control" ...' message that is Django enforces in Similar to Python's built in `dictionary[key] = value`,
select multiple fields on ModelForms. (Required for 1.5 and earlier) but takes a list of nested keys instead of a single key.
See https://code.djangoproject.com/ticket/9321 set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
""" """
multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') if not keys:
multiple_choice_msg = force_text(multiple_choice_msg) dictionary.update(value)
return
return help_text.replace(multiple_choice_msg, '') for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
class Field(object):
read_only = True
creation_counter = 0
empty = ''
type_name = None
partial = False
use_files = False
form_field_class = forms.CharField
type_label = 'field'
widget = None
def __init__(self, source=None, label=None, help_text=None):
self.parent = None
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
self.source = source class SkipField(Exception):
pass
if label is not None:
self.label = smart_text(label)
else:
self.label = None
if help_text is not None: NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
self.help_text = strip_multiple_choice_msg(smart_text(help_text)) NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
else: NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
self.help_text = None NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
)
self._errors = []
self._value = None
self._name = None
@property class Field(object):
def errors(self): _creation_counter = 0
return self._errors
def widget_html(self): default_error_messages = {
if not self.widget: 'required': _('This field is required.')
return '' }
default_validators = []
attrs = {} def __init__(self, read_only=False, write_only=False,
if 'id' not in self.widget.attrs: required=None, default=empty, initial=None, source=None,
attrs['id'] = self._name label=None, help_text=None, style=None,
error_messages=None, validators=[]):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
return self.widget.render(self._name, self._value, attrs=attrs) # If `required` is unset, then use `True` unless a default is provided.
if required is None:
required = default is empty and not read_only
def label_tag(self): # Some combinations of keyword arguments do not make sense.
return '<label for="%s">%s:</label>' % (self._name, self.label) assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
def initialize(self, parent, field_name): self.read_only = read_only
""" self.write_only = write_only
Called to set up a field prior to field_to_native or field_from_native. self.required = required
self.default = default
self.source = source
self.initial = initial
self.label = label
self.help_text = help_text
self.style = {} if style is None else style
self.validators = validators or self.default_validators[:]
parent - The parent serializer. # Collect default error message from self and parent classes
field_name - The name of the field being initialized. messages = {}
""" for cls in reversed(self.__class__.__mro__):
self.parent = parent messages.update(getattr(cls, 'default_error_messages', {}))
self.root = parent.root or parent messages.update(error_messages or {})
self.context = self.root.context self.error_messages = messages
self.partial = self.root.partial
if self.partial:
self.required = False
def field_from_native(self, data, files, field_name, into): def __new__(cls, *args, **kwargs):
""" """
Given a dictionary and a field name, updates the dictionary `into`, When a field is instantiated, we store the arguments that were used,
with the field and it's deserialized value. so that we can present a helpful representation of the object.
""" """
return instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def field_to_native(self, obj, field_name): def bind(self, field_name, parent, root):
""" """
Given an object and a field name, returns the value that should be Setup the context for the field instance.
serialized for that field.
""" """
if obj is None: self.field_name = field_name
return self.empty self.parent = parent
self.root = root
if self.source == '*': self.context = parent.context
return self.to_native(obj)
source = self.source or field_name # `self.label` should deafult to being based on the field name.
value = obj if self.label is None:
self.label = field_name.replace('_', ' ').capitalize()
for component in source.split('.'): # self.source should default to being the same as the field name.
value = get_component(value, component) if self.source is None:
if value is None: self.source = field_name
break
return self.to_native(value) # self.source_attrs is a list of attributes that need to be looked up
# when serializing the instance, or populating the validated data.
if self.source == '*':
self.source_attrs = []
else:
self.source_attrs = self.source.split('.')
def to_native(self, value): def get_initial(self):
""" """
Converts the field's value into it's simple representation. Return a value to use when the field is being returned as a primative
value, without any object instance.
""" """
if is_simple_callable(value): return self.initial
value = value()
if is_protected_type(value): def get_value(self, dictionary):
return value
elif (is_non_str_iterable(value) and
not isinstance(value, (dict, six.string_types))):
return [self.to_native(item) for item in value]
elif isinstance(value, dict):
# Make sure we preserve field ordering, if it exists
ret = SortedDict()
for key, val in value.items():
ret[key] = self.to_native(val)
return ret
return force_text(value)
def attributes(self):
""" """
Returns a dictionary of attributes to be used when serializing to xml. Given the *incoming* primative data, return the value for this field
that should be validated and transformed to a native value.
""" """
if self.type_name: return dictionary.get(self.field_name, empty)
return {'type': self.type_name}
return {} def get_attribute(self, instance):
def metadata(self):
metadata = SortedDict()
metadata['type'] = self.type_label
metadata['required'] = getattr(self, 'required', False)
optional_attrs = ['read_only', 'label', 'help_text',
'min_length', 'max_length']
for attr in optional_attrs:
value = getattr(self, attr, None)
if value is not None and value != '':
metadata[attr] = force_text(value, strings_only=True)
return metadata
class WritableField(Field):
""" """
Base for read/write fields. Given the *outgoing* object instance, return the value for this field
that should be returned as a primative value.
""" """
write_only = False return get_attribute(instance, self.source_attrs)
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
widget = widgets.TextInput
default = None
def __init__(self, source=None, label=None, help_text=None, def get_default(self):
read_only=False, write_only=False, required=None, """
validators=[], error_messages=None, widget=None, Return the default value to use when validating data if no input
default=None, blank=None): is provided for this field.
super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
self.write_only = write_only
assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
if required is None:
self.required = not(read_only)
else:
assert not (read_only and required), "Cannot set required=True and read_only=True"
self.required = required
messages = {}
for c in reversed(self.__class__.__mro__):
messages.update(getattr(c, 'default_error_messages', {}))
messages.update(error_messages or {})
self.error_messages = messages
self.validators = self.default_validators + validators
self.default = default if default is not None else self.default
# Widgets are only used for HTML forms. If a default has not been set for this field then this will simply
widget = widget or self.widget return `empty`, indicating that no value should be set in the
if isinstance(widget, type): validated data for this field.
widget = widget() """
self.widget = widget if self.default is empty:
raise SkipField()
return self.default
def __deepcopy__(self, memo): def run_validation(self, data=empty):
result = copy.copy(self) """
memo[id(self)] = result Validate a simple representation and return the internal value.
result.validators = self.validators[:]
return result
def get_default_value(self): The provided data may be `empty` if no representation was included.
if is_simple_callable(self.default): May return `empty` if the field should not be included in the
return self.default() validated data.
return self.default """
if data is empty:
if self.required:
self.fail('required')
return self.get_default()
def validate(self, value): value = self.to_internal_value(data)
if value in validators.EMPTY_VALUES and self.required: self.run_validators(value)
raise ValidationError(self.error_messages['required']) return value
def run_validators(self, value): def run_validators(self, value):
if value in validators.EMPTY_VALUES: if value in (None, '', [], (), {}):
return return
errors = [] errors = []
for v in self.validators: for validator in self.validators:
try: try:
v(value) validator(value)
except ValidationError as e: except ValidationError as exc:
if hasattr(e, 'code') and e.code in self.error_messages: errors.extend(exc.messages)
message = self.error_messages[e.code]
if e.params:
message = message % e.params
errors.append(message)
else:
errors.extend(e.messages)
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
def field_to_native(self, obj, field_name): def to_internal_value(self, data):
if self.write_only:
return None
return super(WritableField, self).field_to_native(obj, field_name)
def field_from_native(self, data, files, field_name, into):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Transform the *incoming* primative data into a native value.
with the field and it's deserialized value.
""" """
if self.read_only: raise NotImplementedError('to_internal_value() must be implemented.')
return
try:
data = data or {}
if self.use_files:
files = files or {}
try:
native = files[field_name]
except KeyError:
native = data[field_name]
else:
native = data[field_name]
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
native = self.get_default_value()
else:
if self.required:
raise ValidationError(self.error_messages['required'])
return
value = self.from_native(native)
if self.source == '*':
if value:
into.update(value)
else:
self.validate(value)
self.run_validators(value)
into[self.source or field_name] = value
def from_native(self, value): def to_representation(self, value):
""" """
Reverts a simple representation back to the field's value. Transform the *outgoing* native value into primative data.
""" """
return value raise NotImplementedError('to_representation() must be implemented.')
class ModelField(WritableField): def fail(self, key, **kwargs):
""" """
A generic field that can be used against an arbitrary model field. A helper method that simply raises a validation error.
""" """
def __init__(self, *args, **kwargs):
try: try:
self.model_field = kwargs.pop('model_field') msg = self.error_messages[key]
except KeyError: except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg") class_name = self.__class__.__name__
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
self.min_length = kwargs.pop('min_length', raise AssertionError(msg)
getattr(self.model_field, 'min_length', None)) raise ValidationError(msg.format(**kwargs))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
self.min_value = kwargs.pop('min_value',
getattr(self.model_field, 'min_value', None))
self.max_value = kwargs.pop('max_value',
getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
if self.min_value is not None:
self.validators.append(validators.MinValueValidator(self.min_value))
if self.max_value is not None:
self.validators.append(validators.MaxValueValidator(self.max_value))
def from_native(self, value): def __repr__(self):
rel = getattr(self.model_field, "rel", None) return representation.field_repr(self)
if rel is not None:
return rel.to._meta.get_field(rel.field_name).to_python(value)
else:
return self.model_field.to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
return self.model_field.value_to_string(obj)
def attributes(self):
return {
"type": self.model_field.get_internal_type()
}
# Typed Fields # Boolean types...
class BooleanField(WritableField): class BooleanField(Field):
type_name = 'BooleanField'
type_label = 'boolean'
form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = { default_error_messages = {
'invalid': _("'%s' value must be either True or False."), 'invalid': _('`{input}` is not a valid boolean.')
} }
empty = False TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
def field_from_native(self, data, files, field_name, into):
# HTML checkboxes do not explicitly represent unchecked as `False` def get_value(self, dictionary):
# we deal with that here... if html.is_html_input(dictionary):
if isinstance(data, QueryDict) and self.default is None: # HTML forms do not send a `False` value on an empty checkbox,
self.default = False # so we override the default empty value to be False.
return dictionary.get(self.field_name, False)
return super(BooleanField, self).field_from_native( return dictionary.get(self.field_name, empty)
data, files, field_name, into
) def to_internal_value(self, data):
if data in self.TRUE_VALUES:
return True
elif data in self.FALSE_VALUES:
return False
self.fail('invalid', input=data)
def from_native(self, value): def to_representation(self, value):
if value in ('true', 't', 'True', '1'): if value is None:
return None
if value in self.TRUE_VALUES:
return True return True
if value in ('false', 'f', 'False', '0'): elif value in self.FALSE_VALUES:
return False return False
return bool(value) return bool(value)
class CharField(WritableField): # String types...
type_name = 'CharField'
type_label = 'string'
form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): class CharField(Field):
self.max_length, self.min_length = max_length, min_length default_error_messages = {
self.allow_none = allow_none 'blank': _('This field may not be blank.')
super(CharField, self).__init__(*args, **kwargs) }
if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length))
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value): def __init__(self, **kwargs):
if isinstance(value, six.string_types): self.allow_blank = kwargs.pop('allow_blank', False)
return value self.max_length = kwargs.pop('max_length', None)
self.min_length = kwargs.pop('min_length', None)
super(CharField, self).__init__(**kwargs)
def to_internal_value(self, data):
if data == '' and not self.allow_blank:
self.fail('blank')
if data is None:
return None
return str(data)
def to_representation(self, value):
if value is None: if value is None:
if not self.allow_none:
return ''
else:
# Return None explicitly because smart_text(None) == 'None'. See #1834 for details
return None return None
return str(value)
return smart_text(value)
class EmailField(CharField):
default_error_messages = {
'invalid': _('Enter a valid email address.')
}
default_validators = [validators.validate_email]
class URLField(CharField): def to_internal_value(self, data):
type_name = 'URLField' if data == '' and not self.allow_blank:
type_label = 'url' self.fail('blank')
if data is None:
return None
return str(data).strip()
def __init__(self, **kwargs): def to_representation(self, value):
if 'validators' not in kwargs: if value is None:
kwargs['validators'] = [validators.URLValidator()] return None
super(URLField, self).__init__(**kwargs) return str(value).strip()
class SlugField(CharField): class RegexField(CharField):
type_name = 'SlugField' def __init__(self, regex, **kwargs):
type_label = 'slug' kwargs['validators'] = (
form_field_class = forms.SlugField [validators.RegexValidator(regex)] +
kwargs.get('validators', [])
)
super(RegexField, self).__init__(**kwargs)
class SlugField(CharField):
default_error_messages = { default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers," 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
" underscores or hyphens."),
} }
default_validators = [validators.validate_slug] default_validators = [validators.validate_slug]
def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs)
class URLField(CharField):
class ChoiceField(WritableField):
type_name = 'ChoiceField'
type_label = 'choice'
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = { default_error_messages = {
'invalid_choice': _('Select a valid choice. %(value)s is not one of ' 'invalid': _("Enter a valid URL.")
'the available choices.'),
} }
default_validators = [validators.URLValidator()]
def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
if blank_display_value is None:
blank_choice = BLANK_CHOICE_DASH
else:
blank_choice = [('', blank_display_value)]
self.choices = blank_choice + self.choices
def _get_choices(self): # Number types...
return self._choices
def _set_choices(self, value): class IntegerField(Field):
# Setting choices also sets the choices on the widget. default_error_messages = {
# choices can be any iterable, but we call list() on it because 'invalid': _('A valid integer is required.')
# it will be consumed more than once. }
self._choices = self.widget.choices = list(value)
choices = property(_get_choices, _set_choices) def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def metadata(self): def to_internal_value(self, data):
data = super(ChoiceField, self).metadata() try:
data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid')
return data return data
def validate(self, value): def to_representation(self, value):
""" if value is None:
Validates that the input is in self.choices. return None
""" return int(value)
super(ChoiceField, self).validate(value)
if value and not self.valid_value(value):
raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
def valid_value(self, value):
"""
Check to see if the provided value is a valid choice.
"""
for k, v in self.choices:
if isinstance(v, (list, tuple)):
# This is an optgroup, so look inside the group for options
for k2, v2 in v:
if value == smart_text(k2):
return True
else:
if value == smart_text(k) or value == k:
return True
return False
def from_native(self, value): class FloatField(Field):
value = super(ChoiceField, self).from_native(value) default_error_messages = {
if value == self.empty or value in validators.EMPTY_VALUES: 'invalid': _("'%s' value must be a float."),
return self.empty }
return value
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
class EmailField(CharField): def to_internal_value(self, value):
type_name = 'EmailField' if value is None:
type_label = 'email' return None
form_field_class = forms.EmailField return float(value)
def to_representation(self, value):
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
self.fail('invalid', value=value)
class DecimalField(Field):
default_error_messages = { default_error_messages = {
'invalid': _('Enter a valid email address.'), 'invalid': _('Enter a number.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
} }
default_validators = [validators.validate_email]
def from_native(self, value): coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
ret = super(EmailField, self).from_native(value)
if ret is None: def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs):
self.max_digits = max_digits
self.decimal_places = decimal_places
self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_internal_value(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
if value in (None, ''):
return None return None
return ret.strip()
value = smart_text(value).strip()
try:
value = decimal.Decimal(value)
except decimal.DecimalException:
self.fail('invalid')
class RegexField(CharField): # Check for NaN. It is the only value that isn't equal to itself,
type_name = 'RegexField' # so we can use this to identify NaN values.
type_label = 'regex' if value != value:
form_field_class = forms.RegexField self.fail('invalid')
def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): # Check for infinity and negative infinity.
super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
self.regex = regex self.fail('invalid')
def _get_regex(self): sign, digittuple, exponent = value.as_tuple()
return self._regex decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
self.fail('max_digits', max_digits=self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
return value
def _set_regex(self, regex): def to_representation(self, value):
if isinstance(regex, six.string_types): if isinstance(value, decimal.Decimal):
regex = re.compile(regex) context = decimal.getcontext().copy()
self._regex = regex context.prec = self.max_digits
if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: quantized = value.quantize(
self.validators.remove(self._regex_validator) decimal.Decimal('.1') ** self.decimal_places,
self._regex_validator = validators.RegexValidator(regex=regex) context=context
self.validators.append(self._regex_validator) )
if not self.coerce_to_string:
return quantized
return '{0:f}'.format(quantized)
regex = property(_get_regex, _set_regex) if not self.coerce_to_string:
return value
return '%.*f' % (self.max_decimal_places, value)
class DateField(WritableField): # Date & time fields...
type_name = 'DateField'
type_label = 'date'
widget = widgets.DateInput
form_field_class = forms.DateField
class DateField(Field):
default_error_messages = { default_error_messages = {
'invalid': _("Date has wrong format. Use one of these formats instead: %s"), 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
} }
empty = None
input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT format = api_settings.DATE_FORMAT
input_formats = api_settings.DATE_INPUT_FORMATS
def __init__(self, input_formats=None, format=None, *args, **kwargs): def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format self.format = format if format is not None else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateField, self).__init__(*args, **kwargs) super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value): def to_internal_value(self, value):
if value in validators.EMPTY_VALUES: if value in (None, ''):
return None return None
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
...@@ -647,6 +531,7 @@ class DateField(WritableField): ...@@ -647,6 +531,7 @@ class DateField(WritableField):
default_timezone = timezone.get_default_timezone() default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone) value = timezone.make_naive(value, default_timezone)
return value.date() return value.date()
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
return value return value
...@@ -667,10 +552,10 @@ class DateField(WritableField): ...@@ -667,10 +552,10 @@ class DateField(WritableField):
else: else:
return parsed.date() return parsed.date()
msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) humanized_format = humanize_datetime.date_formats(self.input_formats)
raise ValidationError(msg) self.fail('invalid', format=humanized_format)
def to_native(self, value): def to_representation(self, value):
if value is None or self.format is None: if value is None or self.format is None:
return value return value
...@@ -682,30 +567,25 @@ class DateField(WritableField): ...@@ -682,30 +567,25 @@ class DateField(WritableField):
return value.strftime(self.format) return value.strftime(self.format)
class DateTimeField(WritableField): class DateTimeField(Field):
type_name = 'DateTimeField'
type_label = 'datetime'
widget = widgets.DateTimeInput
form_field_class = forms.DateTimeField
default_error_messages = { default_error_messages = {
'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
} }
empty = None
input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT format = api_settings.DATETIME_FORMAT
input_formats = api_settings.DATETIME_INPUT_FORMATS
def __init__(self, input_formats=None, format=None, *args, **kwargs): def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format self.format = format if format is not None else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateTimeField, self).__init__(*args, **kwargs) super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value): def to_internal_value(self, value):
if value in validators.EMPTY_VALUES: if value in (None, ''):
return None return None
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return value return value
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day) value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ: if settings.USE_TZ:
...@@ -737,10 +617,10 @@ class DateTimeField(WritableField): ...@@ -737,10 +617,10 @@ class DateTimeField(WritableField):
else: else:
return parsed return parsed
msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) humanized_format = humanize_datetime.datetime_formats(self.input_formats)
raise ValidationError(msg) self.fail('invalid', format=humanized_format)
def to_native(self, value): def to_representation(self, value):
if value is None or self.format is None: if value is None or self.format is None:
return value return value
...@@ -752,26 +632,20 @@ class DateTimeField(WritableField): ...@@ -752,26 +632,20 @@ class DateTimeField(WritableField):
return value.strftime(self.format) return value.strftime(self.format)
class TimeField(WritableField): class TimeField(Field):
type_name = 'TimeField'
type_label = 'time'
widget = widgets.TimeInput
form_field_class = forms.TimeField
default_error_messages = { default_error_messages = {
'invalid': _("Time has wrong format. Use one of these formats instead: %s"), 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),
} }
empty = None
input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
def __init__(self, input_formats=None, format=None, *args, **kwargs): def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format self.format = format if format is not None else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs) super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value): def from_native(self, value):
if value in validators.EMPTY_VALUES: if value in (None, ''):
return None return None
if isinstance(value, datetime.time): if isinstance(value, datetime.time):
...@@ -794,10 +668,10 @@ class TimeField(WritableField): ...@@ -794,10 +668,10 @@ class TimeField(WritableField):
else: else:
return parsed.time() return parsed.time()
msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) humanized_format = humanize_datetime.time_formats(self.input_formats)
raise ValidationError(msg) self.fail('invalid', format=humanized_format)
def to_native(self, value): def to_representation(self, value):
if value is None or self.format is None: if value is None or self.format is None:
return value return value
...@@ -809,234 +683,147 @@ class TimeField(WritableField): ...@@ -809,234 +683,147 @@ class TimeField(WritableField):
return value.strftime(self.format) return value.strftime(self.format)
class IntegerField(WritableField): # Choice types...
type_name = 'IntegerField'
type_label = 'integer'
form_field_class = forms.IntegerField
empty = 0
class ChoiceField(Field):
default_error_messages = { default_error_messages = {
'invalid': _('Enter a whole number.'), 'invalid_choice': _('`{input}` is not a valid choice.')
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
} }
def __init__(self, max_value=None, min_value=None, *args, **kwargs): def __init__(self, choices, **kwargs):
self.max_value, self.min_value = max_value, min_value # Allow either single or paired choices style:
super(IntegerField, self).__init__(*args, **kwargs) # choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
pairs = [
isinstance(item, (list, tuple)) and len(item) == 2
for item in choices
]
if all(pairs):
self.choices = dict([(key, display_value) for key, display_value in choices])
else:
self.choices = dict([(item, item) for item in choices])
if max_value is not None: # Map the string representation of choices to the underlying value.
self.validators.append(validators.MaxValueValidator(max_value)) # Allows us to deal with eg. integer choices while supporting either
if min_value is not None: # integer or string input, but still get the correct datatype out.
self.validators.append(validators.MinValueValidator(min_value)) self.choice_strings_to_values = dict([
(str(key), key) for key in self.choices.keys()
])
def from_native(self, value): super(ChoiceField, self).__init__(**kwargs)
if value in validators.EMPTY_VALUES:
return None
def to_internal_value(self, data):
try: try:
value = int(str(value)) return self.choice_strings_to_values[str(data)]
except (ValueError, TypeError): except KeyError:
raise ValidationError(self.error_messages['invalid']) self.fail('invalid_choice', input=data)
return value
def to_representation(self, value):
return value
class FloatField(WritableField):
type_name = 'FloatField'
type_label = 'float'
form_field_class = forms.FloatField
empty = 0
class MultipleChoiceField(ChoiceField):
default_error_messages = { default_error_messages = {
'invalid': _("'%s' value must be a float."), 'invalid_choice': _('`{input}` is not a valid choice.'),
'not_a_list': _('Expected a list of items but got type `{input_type}`')
} }
def from_native(self, value): def to_internal_value(self, data):
if value in validators.EMPTY_VALUES: if not hasattr(data, '__iter__'):
return None self.fail('not_a_list', input_type=type(data).__name__)
return set([
super(MultipleChoiceField, self).to_internal_value(item)
for item in data
])
try: def to_representation(self, value):
return float(value) return value
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
class DecimalField(WritableField): # File types...
type_name = 'DecimalField'
type_label = 'decimal'
form_field_class = forms.DecimalField
empty = Decimal('0')
default_error_messages = { class FileField(Field):
'invalid': _('Enter a number.'), pass # TODO
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
'max_digits': _('Ensure that there are no more than %s digits in total.'),
'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
}
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
self.max_value, self.min_value = max_value, min_value
self.max_digits, self.decimal_places = max_digits, decimal_places
super(DecimalField, self).__init__(*args, **kwargs)
if max_value is not None: class ImageField(Field):
self.validators.append(validators.MaxValueValidator(max_value)) pass # TODO
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
if value in validators.EMPTY_VALUES:
return None
value = smart_text(value).strip()
try:
value = Decimal(value)
except DecimalException:
raise ValidationError(self.error_messages['invalid'])
return value
def validate(self, value):
super(DecimalField, self).validate(value)
if value in validators.EMPTY_VALUES:
return
# Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
# since it is never equal to itself. However, NaN is the only value that
# isn't equal to itself, so we can use this to identify NaN
if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
raise ValidationError(self.error_messages['invalid'])
sign, digittuple, exponent = value.as_tuple()
decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
return value
# Advanced field types...
class FileField(WritableField): class ReadOnlyField(Field):
use_files = True """
type_name = 'FileField' A read-only field that simply returns the field value.
type_label = 'file upload'
form_field_class = forms.FileField
widget = widgets.FileInput
default_error_messages = { If the field is a method with no parameters, the method will be called
'invalid': _("No file was submitted. Check the encoding type on the form."), and it's return value used as the representation.
'missing': _("No file was submitted."),
'empty': _("The submitted file is empty."),
'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
}
def __init__(self, *args, **kwargs): For example, the following would call `get_expiry_date()` on the object:
self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False)
super(FileField, self).__init__(*args, **kwargs)
def from_native(self, data): class ExampleSerializer(self):
if data in validators.EMPTY_VALUES: expiry_date = ReadOnlyField(source='get_expiry_date')
return None """
# UploadedFile objects should have name and size attributes. def __init__(self, **kwargs):
try: kwargs['read_only'] = True
file_name = data.name super(ReadOnlyField, self).__init__(**kwargs)
file_size = data.size
except AttributeError:
raise ValidationError(self.error_messages['invalid'])
if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values)
if not file_name:
raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size:
raise ValidationError(self.error_messages['empty'])
return data def to_representation(self, value):
if is_simple_callable(value):
return value()
return value
def to_native(self, value):
return value.name
class SerializerMethodField(Field):
"""
A read-only field that get its representation from calling a method on the
parent serializer class. The method called will be of the form
"get_{field_name}", and should take a single argument, which is the
object being serialized.
class ImageField(FileField): For example:
use_files = True
type_name = 'ImageField'
type_label = 'image upload'
form_field_class = forms.ImageField
default_error_messages = { class ExampleSerializer(self):
'invalid_image': _("Upload a valid image. The file you uploaded was " extra_info = SerializerMethodField()
"either not an image or a corrupted image."),
}
def from_native(self, data): def get_extra_info(self, obj):
return ... # Calculate some data to return.
""" """
Checks that the file-upload field data contains a valid image (GIF, JPG, def __init__(self, method_attr=None, **kwargs):
PNG, possibly others -- whatever the Python Imaging Library supports). self.method_attr = method_attr
""" kwargs['source'] = '*'
f = super(ImageField, self).from_native(data) kwargs['read_only'] = True
if f is None: super(SerializerMethodField, self).__init__(**kwargs)
return None
from rest_framework.compat import Image
assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
if hasattr(data, 'temporary_file_path'):
file = data.temporary_file_path()
else:
if hasattr(data, 'read'):
file = BytesIO(data.read())
else:
file = BytesIO(data['content'])
try: def to_representation(self, value):
# load() could spot a truncated JPEG, but it loads the entire method_attr = self.method_attr
# image in memory, which is a DoS vector. See #3848 and #18520. if method_attr is None:
# verify() must be called immediately after the constructor. method_attr = 'get_{field_name}'.format(field_name=self.field_name)
Image.open(file).verify() method = getattr(self.parent, method_attr)
except ImportError: return method(value)
# Under PyPy, it is possible to import PIL. However, the underlying
# _imaging C module isn't available, so an ImportError will be
# raised. Catch and re-raise.
raise
except Exception: # Python Imaging Library doesn't recognize it as an image
raise ValidationError(self.error_messages['invalid_image'])
if hasattr(f, 'seek') and callable(f.seek):
f.seek(0)
return f
class SerializerMethodField(Field): class ModelField(Field):
""" """
A field that gets its value by calling a method on the serializer it's attached to. A generic field that can be used against an arbitrary model field.
This is used by `ModelSerializer` when dealing with custom model fields,
that do not have a serializer field to be mapped to.
""" """
def __init__(self, model_field, **kwargs):
self.model_field = model_field
kwargs['source'] = '*'
super(ModelField, self).__init__(**kwargs)
def __init__(self, method_name, *args, **kwargs): def to_internal_value(self, data):
self.method_name = method_name rel = getattr(self.model_field, 'rel', None)
super(SerializerMethodField, self).__init__(*args, **kwargs) if rel is not None:
return rel.to._meta.get_field(rel.field_name).to_python(data)
return self.model_field.to_python(data)
def field_to_native(self, obj, field_name): def to_representation(self, obj):
value = getattr(self.parent, self.method_name)(obj) value = self.model_field._get_val_from_obj(obj)
return self.to_native(value) if is_protected_type(value):
return value
return self.model_field.value_to_string(obj)
...@@ -3,7 +3,8 @@ Generic views that provide commonly needed behaviour. ...@@ -3,7 +3,8 @@ Generic views that provide commonly needed behaviour.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.db.models.query import QuerySet
from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage from django.core.paginator import Paginator, InvalidPage
from django.http import Http404 from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404 from django.shortcuts import get_object_or_404 as _get_object_or_404
...@@ -11,7 +12,6 @@ from django.utils.translation import ugettext as _ ...@@ -11,7 +12,6 @@ from django.utils.translation import ugettext as _
from rest_framework import views, mixins, exceptions from rest_framework import views, mixins, exceptions
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import warnings
def strict_positive_int(integer_string, cutoff=None): def strict_positive_int(integer_string, cutoff=None):
...@@ -28,7 +28,7 @@ def strict_positive_int(integer_string, cutoff=None): ...@@ -28,7 +28,7 @@ def strict_positive_int(integer_string, cutoff=None):
def get_object_or_404(queryset, *filter_args, **filter_kwargs): def get_object_or_404(queryset, *filter_args, **filter_kwargs):
""" """
Same as Django's standard shortcut, but make sure to raise 404 Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types. if the filter_kwargs don't match the required types.
""" """
try: try:
...@@ -51,11 +51,6 @@ class GenericAPIView(views.APIView): ...@@ -51,11 +51,6 @@ class GenericAPIView(views.APIView):
queryset = None queryset = None
serializer_class = None serializer_class = None
# This shortcut may be used instead of setting either or both
# of the `queryset`/`serializer_class` attributes, although using
# the explicit style is generally preferred.
model = None
# If you want to use object lookups other than pk, set this attribute. # If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`. # For more complex lookup requirements override `get_object()`.
lookup_field = 'pk' lookup_field = 'pk'
...@@ -71,20 +66,10 @@ class GenericAPIView(views.APIView): ...@@ -71,20 +66,10 @@ class GenericAPIView(views.APIView):
# The filter backend classes to use for queryset filtering # The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
# The following attributes may be subject to change, # The following attribute may be subject to change,
# and should be considered private API. # and should be considered private API.
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
paginator_class = Paginator paginator_class = Paginator
######################################
# These are pending deprecation...
pk_url_kwarg = 'pk'
slug_url_kwarg = 'slug'
slug_field = 'slug'
allow_empty = True
filter_backend = api_settings.FILTER_BACKEND
def get_serializer_context(self): def get_serializer_context(self):
""" """
Extra context provided to the serializer class. Extra context provided to the serializer class.
...@@ -95,18 +80,16 @@ class GenericAPIView(views.APIView): ...@@ -95,18 +80,16 @@ class GenericAPIView(views.APIView):
'view': self 'view': self
} }
def get_serializer(self, instance=None, data=None, files=None, many=False, def get_serializer(self, instance=None, data=None, many=False, partial=False):
partial=False, allow_add_remove=False):
""" """
Return the serializer instance that should be used for validating and Return the serializer instance that should be used for validating and
deserializing input, and for serializing output. deserializing input, and for serializing output.
""" """
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files, return serializer_class(
many=many, partial=partial, instance, data=data, many=many, partial=partial, context=context
allow_add_remove=allow_add_remove, )
context=context)
def get_pagination_serializer(self, page): def get_pagination_serializer(self, page):
""" """
...@@ -120,37 +103,16 @@ class GenericAPIView(views.APIView): ...@@ -120,37 +103,16 @@ class GenericAPIView(views.APIView):
context = self.get_serializer_context() context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context) return pagination_serializer_class(instance=page, context=context)
def paginate_queryset(self, queryset, page_size=None): def paginate_queryset(self, queryset):
""" """
Paginate a queryset if required, either returning a page object, Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view. or `None` if pagination is not configured for this view.
""" """
deprecated_style = False
if page_size is not None:
warnings.warn('The `page_size` parameter to `paginate_queryset()` '
'is deprecated. '
'Note that the return style of this method is also '
'changed, and will simply return a page object '
'when called without a `page_size` argument.',
DeprecationWarning, stacklevel=2)
deprecated_style = True
else:
# Determine the required page size.
# If pagination is not configured, simply return None.
page_size = self.get_paginate_by() page_size = self.get_paginate_by()
if not page_size: if not page_size:
return None return None
if not self.allow_empty: paginator = self.paginator_class(queryset, page_size)
warnings.warn(
'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
DeprecationWarning, stacklevel=2
)
paginator = self.paginator_class(queryset, page_size,
allow_empty_first_page=self.allow_empty)
page_kwarg = self.kwargs.get(self.page_kwarg) page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1 page = page_kwarg or page_query_param or 1
...@@ -170,8 +132,6 @@ class GenericAPIView(views.APIView): ...@@ -170,8 +132,6 @@ class GenericAPIView(views.APIView):
'message': str(exc) 'message': str(exc)
}) })
if deprecated_style:
return (paginator, page, page.object_list, page.has_other_pages())
return page return page
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
...@@ -191,29 +151,12 @@ class GenericAPIView(views.APIView): ...@@ -191,29 +151,12 @@ class GenericAPIView(views.APIView):
""" """
Returns the list of filter backends that this view requires. Returns the list of filter backends that this view requires.
""" """
if self.filter_backends is None: return list(self.filter_backends)
filter_backends = []
else:
# Note that we are returning a *copy* of the class attribute,
# so that it is safe for the view to mutate it if needed.
filter_backends = list(self.filter_backends)
if not filter_backends and self.filter_backend:
warnings.warn(
'The `filter_backend` attribute and `FILTER_BACKEND` setting '
'are deprecated in favor of a `filter_backends` '
'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
'a *list* of filter backend classes.',
DeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
return filter_backends
# The following methods provide default implementations # The following methods provide default implementations
# that you may want to override for more complex cases. # that you may want to override for more complex cases.
def get_paginate_by(self, queryset=None): def get_paginate_by(self):
""" """
Return the size of pages to use with pagination. Return the size of pages to use with pagination.
...@@ -222,11 +165,6 @@ class GenericAPIView(views.APIView): ...@@ -222,11 +165,6 @@ class GenericAPIView(views.APIView):
Otherwise defaults to using `self.paginate_by`. Otherwise defaults to using `self.paginate_by`.
""" """
if queryset is not None:
warnings.warn('The `queryset` parameter to `get_paginate_by()` '
'is deprecated.',
DeprecationWarning, stacklevel=2)
if self.paginate_by_param: if self.paginate_by_param:
try: try:
return strict_positive_int( return strict_positive_int(
...@@ -248,26 +186,13 @@ class GenericAPIView(views.APIView): ...@@ -248,26 +186,13 @@ class GenericAPIView(views.APIView):
(Eg. admins get full serialization, others get basic serialization) (Eg. admins get full serialization, others get basic serialization)
""" """
serializer_class = self.serializer_class assert self.serializer_class is not None, (
if serializer_class is not None: "'%s' should either include a `serializer_class` attribute, "
return serializer_class "or override the `get_serializer_class()` method."
warnings.warn(
'The `.model` attribute on view classes is now deprecated in favor '
'of the more explicit `serializer_class` and `queryset` attributes.',
DeprecationWarning, stacklevel=2
)
assert self.model is not None, \
"'%s' should either include a 'serializer_class' attribute, " \
"or use the 'model' attribute as a shortcut for " \
"automatically generating a serializer class." \
% self.__class__.__name__ % self.__class__.__name__
)
class DefaultSerializer(self.model_serializer_class): return self.serializer_class
class Meta:
model = self.model
return DefaultSerializer
def get_queryset(self): def get_queryset(self):
""" """
...@@ -284,21 +209,19 @@ class GenericAPIView(views.APIView): ...@@ -284,21 +209,19 @@ class GenericAPIView(views.APIView):
(Eg. return a list of items that is specific to the user) (Eg. return a list of items that is specific to the user)
""" """
if self.queryset is not None: assert self.queryset is not None, (
return self.queryset._clone() "'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method."
if self.model is not None: % self.__class__.__name__
warnings.warn(
'The `.model` attribute on view classes is now deprecated in favor '
'of the more explicit `serializer_class` and `queryset` attributes.',
DeprecationWarning, stacklevel=2
) )
return self.model._default_manager.all()
error_format = "'%s' must define 'queryset' or 'model'" queryset = self.queryset
raise ImproperlyConfigured(error_format % self.__class__.__name__) if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
def get_object(self, queryset=None): def get_object(self):
""" """
Returns the object the view is displaying. Returns the object the view is displaying.
...@@ -306,43 +229,19 @@ class GenericAPIView(views.APIView): ...@@ -306,43 +229,19 @@ class GenericAPIView(views.APIView):
queryset lookups. Eg if objects are referenced using multiple queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf. keyword arguments in the url conf.
""" """
# Determine the base queryset to use.
if queryset is None:
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
else:
pass # Deprecation warning
# Perform the lookup filtering. # Perform the lookup filtering.
# Note that `pk` and `slug` are deprecated styles of lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None) assert lookup_url_kwarg in self.kwargs, (
slug = self.kwargs.get(self.slug_url_kwarg, None)
if lookup is not None:
filter_kwargs = {self.lookup_field: lookup}
elif pk is not None and self.lookup_field == 'pk':
warnings.warn(
'The `pk_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
DeprecationWarning
)
filter_kwargs = {'pk': pk}
elif slug is not None and self.lookup_field == 'pk':
warnings.warn(
'The `slug_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
DeprecationWarning
)
filter_kwargs = {self.slug_field: slug}
else:
raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument ' 'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' % 'attribute on the view correctly.' %
(self.__class__.__name__, self.lookup_field) (self.__class__.__name__, lookup_url_kwarg)
) )
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs) obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied # May raise a permission denied
...@@ -355,34 +254,6 @@ class GenericAPIView(views.APIView): ...@@ -355,34 +254,6 @@ class GenericAPIView(views.APIView):
# #
# The are not called by GenericAPIView directly, # The are not called by GenericAPIView directly,
# but are used by the mixin methods. # but are used by the mixin methods.
def pre_save(self, obj):
"""
Placeholder method for calling before saving an object.
May be used to set attributes on the object that are implicit
in either the request, or the url.
"""
pass
def post_save(self, obj, created=False):
"""
Placeholder method for calling after saving an object.
"""
pass
def pre_delete(self, obj):
"""
Placeholder method for calling before deleting an object.
"""
pass
def post_delete(self, obj):
"""
Placeholder method for calling after deleting an object.
"""
pass
def metadata(self, request): def metadata(self, request):
""" """
Return a dictionary of metadata about the view. Return a dictionary of metadata about the view.
...@@ -540,25 +411,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, ...@@ -540,25 +411,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) return self.destroy(request, *args, **kwargs)
# Deprecated classes
class MultipleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
'Subclassing `MultipleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
DeprecationWarning, stacklevel=2
)
super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
class SingleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
'Subclassing `SingleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
DeprecationWarning, stacklevel=2
)
super(SingleObjectAPIView, self).__init__(*args, **kwargs)
...@@ -6,40 +6,11 @@ which allows mixin classes to be composed in interesting ways. ...@@ -6,40 +6,11 @@ which allows mixin classes to be composed in interesting ways.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import warnings
def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
include = []
if pk:
# Deprecated
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
# Deprecated
include.append(slug_field)
if lookup_field and lookup_field != 'pk':
include.append(lookup_field)
return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object): class CreateModelMixin(object):
...@@ -47,17 +18,11 @@ class CreateModelMixin(object): ...@@ -47,17 +18,11 @@ class CreateModelMixin(object):
Create a model instance. Create a model instance.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(data=request.DATA)
serializer.is_valid(raise_exception=True)
if serializer.is_valid(): serializer.save()
self.pre_save(serializer.object)
self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_success_headers(self, data): def get_success_headers(self, data):
try: try:
...@@ -70,31 +35,13 @@ class ListModelMixin(object): ...@@ -70,31 +35,13 @@ class ListModelMixin(object):
""" """
List a queryset. List a queryset.
""" """
empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
self.object_list = self.filter_queryset(self.get_queryset()) instance = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(instance)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
if not self.allow_empty and not self.object_list:
warnings.warn(
'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
DeprecationWarning
)
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
raise Http404(error_msg)
# Switch between paginated or standard style responses
page = self.paginate_queryset(self.object_list)
if page is not None: if page is not None:
serializer = self.get_pagination_serializer(page) serializer = self.get_pagination_serializer(page)
else: else:
serializer = self.get_serializer(self.object_list, many=True) serializer = self.get_serializer(instance, many=True)
return Response(serializer.data) return Response(serializer.data)
...@@ -103,8 +50,8 @@ class RetrieveModelMixin(object): ...@@ -103,8 +50,8 @@ class RetrieveModelMixin(object):
Retrieve a model instance. Retrieve a model instance.
""" """
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.object = self.get_object() instance = self.get_object()
serializer = self.get_serializer(self.object) serializer = self.get_serializer(instance)
return Response(serializer.data) return Response(serializer.data)
...@@ -114,29 +61,52 @@ class UpdateModelMixin(object): ...@@ -114,29 +61,52 @@ class UpdateModelMixin(object):
""" """
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data)
serializer = self.get_serializer(self.object, data=request.DATA, def partial_update(self, request, *args, **kwargs):
files=request.FILES, partial=partial) kwargs['partial'] = True
return self.update(request, *args, **kwargs)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
try: class DestroyModelMixin(object):
self.pre_save(serializer.object) """
except ValidationError as err: Destroy a model instance.
# full_clean on model instance may be called in pre_save, """
# so we have to handle eventual errors. def destroy(self, request, *args, **kwargs):
return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) instance = self.get_object()
instance.delete()
if self.object is None: return Response(status=status.HTTP_204_NO_CONTENT)
self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True)
# The AllowPUTAsCreateMixin was previously the default behaviour
# for PUT requests. This has now been removed and must be *explictly*
# included if it is the behavior that you want.
# For more info see: ...
class AllowPUTAsCreateMixin(object):
"""
The following mixin class may be used in order to support PUT-as-create
behavior for incoming requests.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer.is_valid(raise_exception=True)
if instance is None:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value}
serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
self.object = serializer.save(force_update=True) serializer.save()
self.post_save(self.object, created=False) return Response(serializer.data)
return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True kwargs['partial'] = True
...@@ -156,41 +126,3 @@ class UpdateModelMixin(object): ...@@ -156,41 +126,3 @@ class UpdateModelMixin(object):
# PATCH requests where the object does not exist should still # PATCH requests where the object does not exist should still
# return a 404 response. # return a 404 response.
raise raise
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
slug_field = slug and self.slug_field or None
if lookup:
setattr(obj, self.lookup_field, lookup)
if pk:
setattr(obj, 'pk', pk)
if slug:
setattr(obj, slug_field, slug)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
obj.full_clean(exclude)
class DestroyModelMixin(object):
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
self.pre_delete(obj)
obj.delete()
self.post_delete(obj)
return Response(status=status.HTTP_204_NO_CONTENT)
...@@ -13,7 +13,7 @@ class NextPageField(serializers.Field): ...@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_representation(self, value):
if not value.has_next(): if not value.has_next():
return None return None
page = value.next_page_number() page = value.next_page_number()
...@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field): ...@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_representation(self, value):
if not value.has_previous(): if not value.has_previous():
return None return None
page = value.previous_page_number() page = value.previous_page_number()
...@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field): ...@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page) return replace_query_param(url, self.page_field, page)
class DefaultObjectSerializer(serializers.Field): class DefaultObjectSerializer(serializers.ReadOnlyField):
""" """
If no object serializer is specified, then this serializer will be applied If no object serializer is specified, then this serializer will be applied
as the default. as the default.
...@@ -49,25 +49,11 @@ class DefaultObjectSerializer(serializers.Field): ...@@ -49,25 +49,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source) super(DefaultObjectSerializer, self).__init__(source=source)
class PaginationSerializerOptions(serializers.SerializerOptions):
"""
An object that stores the options that may be provided to a
pagination serializer by using the inner `Meta` class.
Accessible on the instance as `serializer.opts`.
"""
def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class',
DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer): class BasePaginationSerializer(serializers.Serializer):
""" """
A base class for pagination serializers to inherit from, A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy. to make implementing custom serializers more easy.
""" """
_options_class = PaginationSerializerOptions
results_field = 'results' results_field = 'results'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -76,22 +62,23 @@ class BasePaginationSerializer(serializers.Serializer): ...@@ -76,22 +62,23 @@ class BasePaginationSerializer(serializers.Serializer):
""" """
super(BasePaginationSerializer, self).__init__(*args, **kwargs) super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field results_field = self.results_field
object_serializer = self.opts.object_serializer_class
if 'context' in kwargs: try:
context_kwarg = {'context': kwargs['context']} object_serializer = self.Meta.object_serializer_class
else: except AttributeError:
context_kwarg = {} object_serializer = DefaultObjectSerializer
self.fields[results_field] = object_serializer(source='object_list', self.fields[results_field] = serializers.ListSerializer(
many=True, child=object_serializer(),
**context_kwarg) source='object_list'
)
self.fields[results_field].bind(results_field, self, self)
class PaginationSerializer(BasePaginationSerializer): class PaginationSerializer(BasePaginationSerializer):
""" """
A default implementation of a pagination serializer. A default implementation of a pagination serializer.
""" """
count = serializers.Field(source='paginator.count') count = serializers.ReadOnlyField(source='paginator.count')
next = NextPageField(source='*') next = NextPageField(source='*')
previous = PreviousPageField(source='*') previous = PreviousPageField(source='*')
...@@ -11,7 +11,7 @@ from django.http import QueryDict ...@@ -11,7 +11,7 @@ from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
from django.utils import six from django.utils import six
from rest_framework.compat import etree, yaml, force_text from rest_framework.compat import etree, yaml, force_text, urlparse
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework import renderers from rest_framework import renderers
import json import json
...@@ -48,7 +48,7 @@ class JSONParser(BaseParser): ...@@ -48,7 +48,7 @@ class JSONParser(BaseParser):
""" """
media_type = 'application/json' media_type = 'application/json'
renderer_class = renderers.UnicodeJSONRenderer renderer_class = renderers.JSONRenderer
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
""" """
...@@ -290,6 +290,22 @@ class FileUploadParser(BaseParser): ...@@ -290,6 +290,22 @@ class FileUploadParser(BaseParser):
try: try:
meta = parser_context['request'].META meta = parser_context['request'].META
disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
return force_text(disposition[1]['filename']) filename_parm = disposition[1]
if 'filename*' in filename_parm:
return self.get_encoded_filename(filename_parm)
return force_text(filename_parm['filename'])
except (AttributeError, KeyError): except (AttributeError, KeyError):
pass pass
def get_encoded_filename(self, filename_parm):
"""
Handle encoded filenames per RFC6266. See also:
http://tools.ietf.org/html/rfc2231#section-4
"""
encoded_filename = force_text(filename_parm['filename*'])
try:
charset, lang, filename = encoded_filename.split('\'', 2)
filename = urlparse.unquote(filename)
except (ValueError, LookupError):
filename = force_text(filename_parm['filename'])
return filename
""" from rest_framework.compat import smart_text, urlparse
Serializer fields that deal with relationships. from rest_framework.fields import Field
These fields allow you to specify the style that should be used to represent
model relationships, including hyperlinks, primary keys, or slugs.
"""
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms
from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.compat import urlparse from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
from rest_framework.compat import smart_text from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404
import warnings from django.db.models.query import QuerySet
from django.utils.translation import ugettext_lazy as _
# Relational fields
# Not actually Writable, but subclasses may need to be.
class RelatedField(WritableField):
"""
Base class for related model fields.
This represents a relationship using the unicode representation of the target.
"""
widget = widgets.Select
many_widget = widgets.SelectMultiple
form_field_class = forms.ChoiceField
many_form_field_class = forms.MultipleChoiceField
null_values = (None, '', 'None')
cache_choices = False
empty_label = None
read_only = True
many = False
def __init__(self, *args, **kwargs):
queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
self.widget = self.many_widget
self.form_field_class = self.many_form_field_class
kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
# Accessed in ModelChoiceIterator django/forms/models.py:1034
# If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
manager = getattr(self.parent.opts.model, self.source or field_name)
if hasattr(manager, 'related'): # Forward
self.queryset = manager.related.model._default_manager.all()
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
# We need this stuff to make form choices work...
def prepare_value(self, obj):
return self.to_native(obj)
def label_from_instance(self, obj):
"""
Return a readable representation for use with eg. select widgets.
"""
desc = smart_text(obj)
ident = smart_text(self.to_native(obj))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
def _get_queryset(self):
return self._queryset
def _set_queryset(self, queryset):
self._queryset = queryset
self.widget.choices = self.choices
queryset = property(_get_queryset, _set_queryset)
def _get_choices(self):
# If self._choices is set, then somebody must have manually set
# the property self.choices. In this case, just return self._choices.
if hasattr(self, '_choices'):
return self._choices
# Otherwise, execute the QuerySet in self.queryset to determine the
# choices dynamically. Return a fresh ModelChoiceIterator that has not been
# consumed. Note that we're instantiating a new ModelChoiceIterator *each*
# time _get_choices() is called (and, thus, each time self.choices is
# accessed) so that we can ensure the QuerySet has not been consumed. This
# construct might look complicated but it allows for lazy evaluation of
# the queryset.
return ModelChoiceIterator(self)
def _set_choices(self, value):
# Setting choices also sets the choices on the widget.
# choices can be any iterable, but we call list() on it because
# it will be consumed more than once.
self._choices = self.widget.choices = list(value)
choices = property(_get_choices, _set_choices)
# Default value handling
def get_default_value(self):
default = super(RelatedField, self).get_default_value()
if self.many and default is None:
return []
return default
# Regular serializer stuff...
def field_to_native(self, obj, field_name):
try:
if self.source == '*':
return self.to_native(obj)
source = self.source or field_name
value = obj
for component in source.split('.'):
if value is None:
break
value = get_component(value, component)
except ObjectDoesNotExist:
return None
if value is None:
return None
if self.many:
if is_simple_callable(getattr(value, 'all', None)):
return [self.to_native(item) for item in value.all()]
else:
# Also support non-queryset iterables.
# This allows us to also support plain lists of related items.
return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
try:
if self.many:
try:
# Form data
value = data.getlist(field_name)
if value == [''] or value == []:
raise KeyError
except AttributeError:
# Non-form data
value = data[field_name]
else:
value = data[field_name]
except KeyError:
if self.partial:
return
value = self.get_default_value()
if value in self.null_values:
if self.required:
raise ValidationError(self.error_messages['required'])
into[(self.source or field_name)] = None
elif self.many:
into[(self.source or field_name)] = [self.from_native(item) for item in value]
else:
into[(self.source or field_name)] = self.from_native(value)
# PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField): class RelatedField(Field):
""" def __init__(self, **kwargs):
Represents a relationship as a pk value. self.queryset = kwargs.pop('queryset', None)
""" assert self.queryset is not None or kwargs.get('read_only', None), (
read_only = False 'Relational field must provide a `queryset` argument, '
'or set read_only=`True`.'
)
assert not (self.queryset is not None and kwargs.get('read_only', None)), (
'Relational fields should not provide a `queryset` argument, '
'when setting read_only=`True`.'
)
super(RelatedField, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelation` classes instead when `many=True` is set.
if kwargs.pop('many', False):
return ManyRelation(
child_relation=cls(*args, **kwargs),
read_only=kwargs.get('read_only', False)
)
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
default_error_messages = { def get_queryset(self):
'does_not_exist': _("Invalid pk '%s' - object does not exist."), queryset = self.queryset
'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), if isinstance(queryset, QuerySet):
} # Ensure queryset is re-evaluated whenever used.
queryset = queryset.all()
return queryset
# TODO: Remove these field hacks...
def prepare_value(self, obj):
return self.to_native(obj.pk)
def label_from_instance(self, obj): class StringRelatedField(Field):
""" """
Return a readable representation for use with eg. select widgets. A read only field that represents its targets using their
plain string representation.
""" """
desc = smart_text(obj)
ident = smart_text(self.to_native(obj.pk))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
# TODO: Possibly change this to just take `obj`, through prob less performant def __init__(self, **kwargs):
def to_native(self, pk): kwargs['read_only'] = True
return pk super(StringRelatedField, self).__init__(**kwargs)
def from_native(self, data): def to_representation(self, value):
if self.queryset is None: return str(value)
raise Exception('Writable related fields must include a `queryset` argument')
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
msg = self.error_messages['does_not_exist'] % smart_text(data)
raise ValidationError(msg)
except (TypeError, ValueError):
received = type(data).__name__
msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
queryset = None
if not self.source:
# Prefer obj.serializable_value for performance reasons
try:
queryset = obj.serializable_value(field_name)
except AttributeError:
pass
if queryset is None:
# RelatedManager (reverse relationship)
source = self.source or field_name
queryset = obj
for component in source.split('.'):
if queryset is None:
return []
queryset = get_component(queryset, component)
# Forward relationship
if is_simple_callable(getattr(queryset, 'all', None)):
return [self.to_native(item.pk) for item in queryset.all()]
else:
# Also support non-queryset iterables.
# This allows us to also support plain lists of related items.
return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
try:
pk = getattr(obj, self.source or field_name).pk
except (ObjectDoesNotExist, AttributeError):
return None
# Forward relationship
return self.to_native(pk)
# Slug relationships
class SlugRelatedField(RelatedField):
"""
Represents a relationship using a unique field on the target.
"""
read_only = False
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = { default_error_messages = {
'does_not_exist': _("Object with %s=%s does not exist."), 'required': 'This field is required.',
'invalid': _('Invalid value.'), 'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.",
'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
} }
def __init__(self, *args, **kwargs): def to_internal_value(self, data):
self.slug_field = kwargs.pop('slug_field', None)
assert self.slug_field, 'slug_field is required'
super(SlugRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
return getattr(obj, self.slug_field)
def from_native(self, data):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(**{self.slug_field: data}) return self.get_queryset().get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'] % self.fail('does_not_exist', pk_value=data)
(self.slug_field, smart_text(data)))
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] self.fail('incorrect_type', data_type=type(data).__name__)
raise ValidationError(msg)
def to_representation(self, value):
return value.pk
# Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField): class HyperlinkedRelatedField(RelatedField):
"""
Represents a relationship using hyperlinking.
"""
read_only = False
lookup_field = 'pk' lookup_field = 'pk'
default_error_messages = { default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'), 'required': 'This field is required.',
'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), 'no_match': 'Invalid hyperlink - No URL match',
'configuration_error': _('Invalid hyperlink due to configuration error'), 'incorrect_match': 'Invalid hyperlink - Incorrect URL match.',
'does_not_exist': _("Invalid hyperlink - object does not exist."), 'does_not_exist': 'Invalid hyperlink - Object does not exist.',
'incorrect_type': _('Incorrect type. Expected url string, received %s.'), 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
} }
# These are all deprecated def __init__(self, view_name=None, **kwargs):
pk_url_kwarg = 'pk' assert view_name is not None, 'The `view_name` argument is required.'
slug_field = 'slug' self.view_name = view_name
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None) self.format = kwargs.pop('format', None)
# These are deprecated # We include these simply for dependancy injection in tests.
if 'pk_url_kwarg' in kwargs: # We can't add them as class attributes or they would expect an
msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' # implict `self` argument to be passed.
warnings.warn(msg, DeprecationWarning, stacklevel=2) self.reverse = reverse
if 'slug_url_kwarg' in kwargs: self.resolve = resolve
msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, DeprecationWarning, stacklevel=2) super(HyperlinkedRelatedField, self).__init__(**kwargs)
if 'slug_field' in kwargs:
msg = 'slug_field is deprecated. Use lookup_field instead.'
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) def get_object(self, view_name, view_args, view_kwargs):
self.slug_field = kwargs.pop('slug_field', self.slug_field) """
default_slug_kwarg = self.slug_url_kwarg or self.slug_field Return the object corresponding to a matched URL.
self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) Takes the matched URL conf arguments, and should return an
object instance, or raise an `ObjectDoesNotExist` exception.
"""
lookup_value = view_kwargs[self.lookup_url_kwarg]
lookup_kwargs = {self.lookup_field: lookup_value}
return self.get_queryset().get(**lookup_kwargs)
def get_url(self, obj, view_name, request, format): def get_url(self, obj, view_name, request, format):
""" """
...@@ -359,176 +115,48 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -359,176 +115,48 @@ class HyperlinkedRelatedField(RelatedField):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field` May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf. attributes are not configured to correctly match the URL conf.
""" """
lookup_field = getattr(obj, self.lookup_field) # Unsaved objects will not yet have a valid URL.
kwargs = {self.lookup_field: lookup_field} if obj.pk is None:
try: return None
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
if self.pk_url_kwarg != 'pk':
# Only try pk if it has been explicitly set.
# Otherwise, the default `lookup_field = 'pk'` has us covered.
pk = obj.pk
kwargs = {self.pk_url_kwarg: pk}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
slug = getattr(obj, self.slug_field, None)
if slug is not None:
# Only try slug if it corresponds to an attribute on the object.
kwargs = {self.slug_url_kwarg: slug}
try:
ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
# If the lookup succeeds using the default slug params,
# then `slug_field` is being used implicitly, and we
# we need to warn about the pending deprecation.
msg = 'Implicit slug field hyperlinked fields are deprecated.' \
'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return ret
except NoReverseMatch:
pass
raise NoReverseMatch()
def get_object(self, queryset, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Takes the matched URL conf arguments, and the queryset, and should
return an object instance, or raise an `ObjectDoesNotExist` exception.
"""
lookup = view_kwargs.get(self.lookup_field, None)
pk = view_kwargs.get(self.pk_url_kwarg, None)
slug = view_kwargs.get(self.slug_url_kwarg, None)
if lookup is not None:
filter_kwargs = {self.lookup_field: lookup}
elif pk is not None:
filter_kwargs = {'pk': pk}
elif slug is not None:
filter_kwargs = {self.slug_field: slug}
else:
raise ObjectDoesNotExist()
return queryset.get(**filter_kwargs)
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
assert request is not None, (
"`HyperlinkedRelatedField` requires the request in the serializer "
"context. Add `context={'request': request}` when instantiating "
"the serializer."
)
# If the object has not yet been saved then we cannot hyperlink to it.
if getattr(obj, 'pk', None) is None:
return
# Return the hyperlink, or error if incorrectly configured.
try:
return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
raise Exception(msg % view_name)
def from_native(self, value): lookup_value = getattr(obj, self.lookup_field)
# Convert URL -> model instance pk kwargs = {self.lookup_url_kwarg: lookup_value}
# TODO: Use values_list return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
queryset = self.queryset
if queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
def to_internal_value(self, data):
try: try:
http_prefix = value.startswith(('http:', 'https:')) http_prefix = data.startswith(('http:', 'https:'))
except AttributeError: except AttributeError:
msg = self.error_messages['incorrect_type'] self.fail('incorrect_type', data_type=type(data).__name__)
raise ValidationError(msg % type(value).__name__)
if http_prefix: if http_prefix:
# If needed convert absolute URLs to relative path # If needed convert absolute URLs to relative path
value = urlparse.urlparse(value).path data = urlparse.urlparse(data).path
prefix = get_script_prefix() prefix = get_script_prefix()
if value.startswith(prefix): if data.startswith(prefix):
value = '/' + value[len(prefix):] data = '/' + data[len(prefix):]
try: try:
match = resolve(value) match = self.resolve(data)
except Exception: except Resolver404:
raise ValidationError(self.error_messages['no_match']) self.fail('no_match')
if match.view_name != self.view_name: if match.view_name != self.view_name:
raise ValidationError(self.error_messages['incorrect_match']) self.fail('incorrect_match')
try: try:
return self.get_object(queryset, match.view_name, return self.get_object(match.view_name, match.args, match.kwargs)
match.args, match.kwargs)
except (ObjectDoesNotExist, TypeError, ValueError): except (ObjectDoesNotExist, TypeError, ValueError):
raise ValidationError(self.error_messages['does_not_exist']) self.fail('does_not_exist')
class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
lookup_field = 'pk'
read_only = True
# These are all deprecated def to_representation(self, value):
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
msg = "HyperlinkedIdentityField requires 'view_name' argument"
raise ValueError(msg)
self.format = kwargs.pop('format', None)
lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field
# These are deprecated
if 'pk_url_kwarg' in kwargs:
msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs:
msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs:
msg = 'slug_field is deprecated. Use lookup_field instead.'
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
request = self.context.get('request', None) request = self.context.get('request', None)
format = self.context.get('format', None) format = self.context.get('format', None)
view_name = self.view_name
assert request is not None, ( assert request is not None, (
"`HyperlinkedIdentityField` requires the request in the serializer" "`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating " " context. Add `context={'request': request}` when instantiating "
"the serializer." "the serializer." % self.__class__.__name__
) )
# By default use whatever format is given for the current context # By default use whatever format is given for the current context
...@@ -545,7 +173,7 @@ class HyperlinkedIdentityField(Field): ...@@ -545,7 +173,7 @@ class HyperlinkedIdentityField(Field):
# Return the hyperlink, or error if incorrectly configured. # Return the hyperlink, or error if incorrectly configured.
try: try:
return self.get_url(obj, view_name, request, format) return self.get_url(value, self.view_name, request, format)
except NoReverseMatch: except NoReverseMatch:
msg = ( msg = (
'Could not resolve URL for hyperlinked relationship using ' 'Could not resolve URL for hyperlinked relationship using '
...@@ -553,43 +181,81 @@ class HyperlinkedIdentityField(Field): ...@@ -553,43 +181,81 @@ class HyperlinkedIdentityField(Field):
'model in your API, or incorrectly configured the ' 'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.' '`lookup_field` attribute on this field.'
) )
raise Exception(msg % view_name) raise ImproperlyConfigured(msg % self.view_name)
def get_url(self, obj, view_name, request, format):
class HyperlinkedIdentityField(HyperlinkedRelatedField):
""" """
Given an object, return the URL that hyperlinks to the object. A read-only field that represents the identity URL for an object, itself.
May raise a `NoReverseMatch` if the `view_name` and `lookup_field` This is in contrast to `HyperlinkedRelatedField` which represents the
attributes are not configured to correctly match the URL conf. URL of relationships to other objects.
""" """
lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field}
# Handle unsaved object case def __init__(self, view_name=None, **kwargs):
if lookup_field is None: assert view_name is not None, 'The `view_name` argument is required.'
return None kwargs['read_only'] = True
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
if self.pk_url_kwarg != 'pk': class SlugRelatedField(RelatedField):
# Only try pk lookup if it has been explicitly set. """
# Otherwise, the default `lookup_field = 'pk'` has us covered. A read-write field the represents the target of the relationship
kwargs = {self.pk_url_kwarg: obj.pk} by a unique 'slug' attribute.
try: """
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
slug = getattr(obj, self.slug_field, None) default_error_messages = {
if slug: 'does_not_exist': _("Object with {slug_name}={value} does not exist."),
# Only use slug lookup if a slug field exists on the model 'invalid': _('Invalid value.'),
kwargs = {self.slug_url_kwarg: slug} }
def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
def to_internal_value(self, data):
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return self.get_queryset().get(**{self.slug_field: data})
except NoReverseMatch: except ObjectDoesNotExist:
pass self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data))
except (TypeError, ValueError):
self.fail('invalid')
def to_representation(self, obj):
return getattr(obj, self.slug_field)
class ManyRelation(Field):
"""
Relationships with `many=True` transparently get coerced into instead being
a ManyRelation with a child relationship.
The `ManyRelation` class is responsible for handling iterating through
the values and passing each one to the child relationship.
You shouldn't need to be using this class directly yourself.
"""
raise NoReverseMatch() def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelation, self).__init__(*args, **kwargs)
def bind(self, field_name, parent, root):
# ManyRelation needs to provide the current context to the child relation.
super(ManyRelation, self).bind(field_name, parent, root)
self.child_relation.bind(field_name, parent, root)
def to_internal_value(self, data):
return [
self.child_relation.to_internal_value(item)
for item in data
]
def to_representation(self, obj):
return [
self.child_relation.to_representation(value)
for value in obj.all()
]
...@@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs ...@@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework import exceptions, status, VERSION from rest_framework import exceptions, status, VERSION
def zero_as_none(value):
return None if value == 0 else value
class BaseRenderer(object): class BaseRenderer(object):
""" """
All renderers should extend this class, setting the `media_type` All renderers should extend this class, setting the `media_type`
...@@ -44,13 +48,13 @@ class BaseRenderer(object): ...@@ -44,13 +48,13 @@ class BaseRenderer(object):
class JSONRenderer(BaseRenderer): class JSONRenderer(BaseRenderer):
""" """
Renderer which serializes to JSON. Renderer which serializes to JSON.
Applies JSON's backslash-u character escaping for non-ascii characters.
""" """
media_type = 'application/json' media_type = 'application/json'
format = 'json' format = 'json'
encoder_class = encoders.JSONEncoder encoder_class = encoders.JSONEncoder
ensure_ascii = True ensure_ascii = not api_settings.UNICODE_JSON
compact = api_settings.COMPACT_JSON
# We don't set a charset because JSON is a binary encoding, # We don't set a charset because JSON is a binary encoding,
# that can be encoded as utf-8, utf-16 or utf-32. # that can be encoded as utf-8, utf-16 or utf-32.
...@@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer): ...@@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type: if accepted_media_type:
# If the media type looks like 'application/json; indent=4', # If the media type looks like 'application/json; indent=4',
# then pretty print the result. # then pretty print the result.
# Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii')) base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
try: try:
return max(min(int(params['indent']), 8), 0) return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError): except (KeyError, ValueError, TypeError):
pass pass
...@@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer): ...@@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {} renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context) indent = self.get_indent(accepted_media_type, renderer_context)
separators = (',', ':') if (indent is None and self.compact) else (', ', ': ')
ret = json.dumps( ret = json.dumps(
data, cls=self.encoder_class, data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii indent=indent, ensure_ascii=self.ensure_ascii,
separators=separators
) )
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
...@@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer): ...@@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer):
return ret return ret
class UnicodeJSONRenderer(JSONRenderer):
ensure_ascii = False
"""
Renderer which serializes to JSON.
Does *not* apply JSON's character escaping for non-ascii characters.
"""
class JSONPRenderer(JSONRenderer): class JSONPRenderer(JSONRenderer):
""" """
Renderer which serializes to json, Renderer which serializes to json,
...@@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer): ...@@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml' format = 'yaml'
encoder = encoders.SafeDumper encoder = encoders.SafeDumper
charset = 'utf-8' charset = 'utf-8'
ensure_ascii = True ensure_ascii = False
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
""" """
...@@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer): ...@@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
class UnicodeYAMLRenderer(YAMLRenderer):
"""
Renderer which serializes to YAML.
Does *not* apply character escaping for non-ascii characters.
"""
ensure_ascii = False
class TemplateHTMLRenderer(BaseRenderer): class TemplateHTMLRenderer(BaseRenderer):
""" """
An HTML renderer for use with templates. An HTML renderer for use with templates.
...@@ -436,13 +427,13 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -436,13 +427,13 @@ class BrowsableAPIRenderer(BaseRenderer):
if request.method == method: if request.method == method:
try: try:
data = request.DATA data = request.DATA
files = request.FILES # files = request.FILES
except ParseError: except ParseError:
data = None data = None
files = None # files = None
else: else:
data = None data = None
files = None # files = None
with override_method(view, request, method) as request: with override_method(view, request, method) as request:
obj = getattr(view, 'object', None) obj = getattr(view, 'object', None)
...@@ -458,7 +449,7 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -458,7 +449,7 @@ class BrowsableAPIRenderer(BaseRenderer):
): ):
return return
serializer = view.get_serializer(instance=obj, data=data, files=files) serializer = view.get_serializer(instance=obj, data=data)
serializer.is_valid() serializer.is_valid()
data = serializer.data data = serializer.data
...@@ -579,10 +570,10 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -579,10 +570,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers, 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request), # 'put_form': self.get_rendered_html_form(view, 'PUT', request),
'post_form': self.get_rendered_html_form(view, 'POST', request), # 'post_form': self.get_rendered_html_form(view, 'POST', request),
'delete_form': self.get_rendered_html_form(view, 'DELETE', request), # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form, 'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form, 'raw_data_post_form': raw_data_post_form,
......
...@@ -19,6 +19,7 @@ import itertools ...@@ -19,6 +19,7 @@ import itertools
from collections import namedtuple from collections import namedtuple
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch
from rest_framework import views from rest_framework import views
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
...@@ -284,10 +285,19 @@ class DefaultRouter(SimpleRouter): ...@@ -284,10 +285,19 @@ class DefaultRouter(SimpleRouter):
class APIRoot(views.APIView): class APIRoot(views.APIView):
_ignore_model_permissions = True _ignore_model_permissions = True
def get(self, request, format=None): def get(self, request, *args, **kwargs):
ret = {} ret = {}
for key, url_name in api_root_dict.items(): for key, url_name in api_root_dict.items():
ret[key] = reverse(url_name, request=request, format=format) try:
ret[key] = reverse(
url_name,
request=request,
format=kwargs.get('format', None)
)
except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes.
continue
return Response(ret) return Response(ret)
return APIRoot.as_view() return APIRoot.as_view()
......
...@@ -10,21 +10,20 @@ python primitives. ...@@ -10,21 +10,20 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured, ValidationError
import copy
import datetime
import inspect
import types
from decimal import Decimal
from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.core.exceptions import ObjectDoesNotExist from collections import namedtuple
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import (
get_url_kwargs, get_field_kwargs,
get_relation_kwargs, get_nested_relation_kwargs,
lookup_class
)
import copy
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
...@@ -37,1107 +36,453 @@ from rest_framework.relations import * # NOQA ...@@ -37,1107 +36,453 @@ from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA from rest_framework.fields import * # NOQA
def _resolve_model(obj): FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format: class BaseSerializer(Field):
'appname.ModelName' """
The BaseSerializer class provides a minimal class which may be used
for writing custom serializer implementations.
""" """
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a Django model".format(obj))
def __init__(self, instance=None, data=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance
self._initial_data = data
def pretty_name(name): def to_internal_value(self, data):
"""Converts 'first_name' to 'First name'""" raise NotImplementedError('`to_internal_value()` must be implemented.')
if not name:
return ''
return name.replace('_', ' ').capitalize()
def to_representation(self, instance):
raise NotImplementedError('`to_representation()` must be implemented.')
class RelationsList(list): def update(self, instance, attrs):
_deleted = [] raise NotImplementedError('`update()` must be implemented.')
def create(self, attrs):
raise NotImplementedError('`create()` must be implemented.')
class NestedValidationError(ValidationError): def save(self, extras=None):
""" attrs = self.validated_data
The default ValidationError behavior is to stringify each item in the list if extras is not None:
if the messages are a list of error messages. attrs = dict(list(attrs.items()) + list(extras.items()))
In the case of nested serializers, where the parent has many children, if self.instance is not None:
then the child's `serializer.errors` will be a list of dicts. In the case self.update(self.instance, attrs)
of a single child, the `serializer.errors` will be a dict. else:
self.instance = self.create(attrs)
We need to override the default behavior to get properly nested error dicts. return self.instance
"""
def __init__(self, message): def is_valid(self, raise_exception=False):
if isinstance(message, dict): if not hasattr(self, '_validated_data'):
self._messages = [message] try:
self._validated_data = self.to_internal_value(self._initial_data)
except ValidationError as exc:
self._validated_data = {}
self._errors = exc.message_dict
else: else:
self._messages = message self._errors = {}
@property if self._errors and raise_exception:
def messages(self): raise ValidationError(self._errors)
return self._messages
return not bool(self._errors)
class DictWithMetadata(dict): @property
""" def data(self):
A dict-like object, that can have additional properties attached. if not hasattr(self, '_data'):
""" if self.instance is not None:
def __getstate__(self): self._data = self.to_representation(self.instance)
""" elif self._initial_data is not None:
Used by pickle (e.g., caching). self._data = dict([
Overridden to remove the metadata from the dict, since it shouldn't be (field_name, field.get_value(self._initial_data))
pickled and may in some instances be unpickleable. for field_name, field in self.fields.items()
""" ])
return dict(self) else:
self._data = self.get_initial()
return self._data
@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)
return self._errors
class SortedDictWithMetadata(SortedDict): @property
""" def validated_data(self):
A sorted dict-like object, that can have additional properties attached. if not hasattr(self, '_validated_data'):
""" msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
def __getstate__(self): raise AssertionError(msg)
""" return self._validated_data
Used by pickle (e.g., caching).
Overriden to remove the metadata from the dict, since it shouldn't be
pickle and may in some instances be unpickleable.
"""
return SortedDict(self).__dict__
def _is_protected_type(obj): class SerializerMetaclass(type):
"""
True if the object is a native datatype that does not need to
be serialized further.
""" """
return isinstance(obj, ( This metaclass sets a dictionary named `base_fields` on the class.
types.NoneType,
int, long,
datetime.datetime, datetime.date, datetime.time,
float, Decimal,
basestring)
)
Any instances of `Field` included as attributes on either the class
def _get_declared_fields(bases, attrs): or on any of its superclasses will be include in the
`base_fields` dictionary.
""" """
Create a list of serializer field instances from the passed in 'attrs',
plus any fields on the base classes (in 'bases').
Note that all fields from the base classes are used. @classmethod
""" def _get_declared_fields(cls, bases, attrs):
fields = [(field_name, attrs.pop(field_name)) fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(six.iteritems(attrs)) for field_name, obj in list(attrs.items())
if isinstance(obj, Field)] if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter) fields.sort(key=lambda x: x[1]._creation_counter)
# If this class is subclassing another Serializer, add that Serializer's # If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary # fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields. # in order to maintain the correct order of fields.
for base in bases[::-1]: for base in bases[::-1]:
if hasattr(base, 'base_fields'): if hasattr(base, '_declared_fields'):
fields = list(base.base_fields.items()) + fields fields = list(base._declared_fields.items()) + fields
return SortedDict(fields) return SortedDict(fields)
class SerializerMetaclass(type):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
attrs['base_fields'] = _get_declared_fields(bases, attrs) attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
class SerializerOptions(object): @six.add_metaclass(SerializerMetaclass)
""" class Serializer(BaseSerializer):
Meta class options for Serializer def __init__(self, *args, **kwargs):
""" self.context = kwargs.pop('context', {})
def __init__(self, meta): kwargs.pop('partial', None)
self.depth = getattr(meta, 'depth', 0) kwargs.pop('many', None)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
class BaseSerializer(WritableField):
"""
This is the Serializer implementation.
We need to implement it as `BaseSerializer` due to metaclass magicks.
"""
class Meta(object):
pass
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=False,
allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
self.allow_add_remove = allow_add_remove
self.context = context or {}
self.init_data = data
self.init_files = files
self.object = instance
self.fields = self.get_fields()
self._data = None
self._files = None
self._errors = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
if allow_add_remove and not many:
raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
def get_default_fields(self):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
def get_fields(self):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
# Get the explicitly declared fields
base_fields = copy.deepcopy(self.base_fields)
for key, field in base_fields.items():
ret[key] = field
# Add in the default fields
default_fields = self.get_default_fields()
for key, val in default_fields.items():
if key not in ret:
ret[key] = val
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
ret = new
# Remove anything in 'exclude'
if self.opts.exclude:
assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
for key, field in ret.items():
field.initialize(parent=self, field_name=key)
return ret
#####
# Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
"""
Return the key that should be used for a given field.
"""
return field_name
def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
reverted_data = {}
if data is not None and not isinstance(data, dict):
self._errors['non_field_errors'] = ['Invalid data']
return None
for field_name, field in self.fields.items(): super(Serializer, self).__init__(*args, **kwargs)
field.initialize(parent=self, field_name=field_name)
try:
field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
return reverted_data # Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
self.fields = self._get_base_fields()
def perform_validation(self, attrs): # Setup all the child fields, to provide them with the current context.
"""
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if field_name in self._errors: field.bind(field_name, self, self)
continue
def __new__(cls, *args, **kwargs):
source = field.source or field_name # We override this method in order to automagically create
if self.partial and source not in attrs: # `ListSerializer` classes instead when `many=True` is set.
continue if kwargs.pop('many', False):
try: kwargs['child'] = cls()
validate_method = getattr(self, 'validate_%s' % field_name, None) return ListSerializer(*args, **kwargs)
if validate_method: return super(Serializer, cls).__new__(cls, *args, **kwargs)
attrs = validate_method(attrs, source)
except ValidationError as err: def _get_base_fields(self):
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) return copy.deepcopy(self._declared_fields)
# If there are already errors, we don't run .validate() because def bind(self, field_name, parent, root):
# field-validation failed and thus `attrs` may not be complete. # If the serializer is used as a field then when it becomes bound
# which in turn can cause inconsistent validation errors. # it also needs to bind all its child fields.
if not self._errors: super(Serializer, self).bind(field_name, parent, root)
try:
attrs = self.validate(attrs)
except ValidationError as err:
if hasattr(err, 'message_dict'):
for field_name, error_messages in err.message_dict.items():
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
elif hasattr(err, 'messages'):
self._errors['non_field_errors'] = err.messages
return attrs
def validate(self, attrs):
"""
Stub method, to be overridden in Serializer subclasses
"""
return attrs
def restore_object(self, attrs, instance=None):
"""
Deserialize a dictionary of attributes into an object instance.
You should override this method to control how deserialized objects
are instantiated.
"""
if instance is not None:
instance.update(attrs)
return instance
return attrs
def to_native(self, obj):
"""
Serialize objects -> primitives.
"""
ret = self._dict_class()
ret.fields = self._dict_class()
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if field.read_only and obj is None: field.bind(field_name, self, root)
continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
method = getattr(self, 'transform_%s' % field_name, None)
if callable(method):
value = method(obj, value)
if not getattr(field, 'write_only', False):
ret[key] = value
ret.fields[key] = self.augment_field(field, field_name, key, value)
return ret
def from_native(self, data, files=None):
"""
Deserialize primitives -> objects.
"""
self._errors = {}
if data is not None or files is not None: def get_initial(self):
attrs = self.restore_fields(data, files) return dict([
if attrs is not None: (field.field_name, field.get_initial())
attrs = self.perform_validation(attrs) for field in self.fields.values()
else: ])
self._errors['non_field_errors'] = ['No input provided']
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
def augment_field(self, field, field_name, key, value): def get_value(self, dictionary):
# This horrible stuff is to manage serializers rendering to HTML # We override the default field access in order to support
field._errors = self._errors.get(key) if self._errors else None # nested HTML forms.
field._name = field_name if html.is_html_input(dictionary):
field._value = self.init_data.get(key) if self._errors and self.init_data else value return html.parse_html_dict(dictionary, prefix=self.field_name)
if not field.label: return dictionary.get(self.field_name, empty)
field.label = pretty_name(key)
return field
def field_to_native(self, obj, field_name): def to_internal_value(self, data):
""" """
Override default so that the serializer can be used as a nested field Dict of native values <- Dict of primitive datatypes.
across relationships.
""" """
if self.write_only: if not isinstance(data, dict):
return None raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
})
if self.source == '*': ret = {}
return self.to_native(obj) errors = {}
fields = [field for field in self.fields.values() if not field.read_only]
# Get the raw field value for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data)
try: try:
source = self.source or field_name validated_value = field.run_validation(primitive_value)
value = obj if validate_method is not None:
validated_value = validate_method(validated_value)
for component in source.split('.'): except ValidationError as exc:
if value is None: errors[field.field_name] = exc.messages
break except SkipField:
value = get_component(value, component) pass
except ObjectDoesNotExist:
return None
if is_simple_callable(getattr(value, 'all', None)):
return [self.to_native(item) for item in value.all()]
if value is None:
return None
if self.many:
return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
"""
Override default so that the serializer can be used as a writable
nested field across relationships.
"""
if self.read_only:
return
try:
value = data[field_name]
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
value = copy.deepcopy(self.default)
else:
if self.required:
raise ValidationError(self.error_messages['required'])
return
if self.source == '*':
if value:
reverted_data = self.restore_fields(value, {})
if not self._errors:
into.update(reverted_data)
else:
if value in (None, ''):
into[(self.source or field_name)] = None
else:
# Set the serializer object if it exists
obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
# If we have a model manager or similar object then we need
# to iterate through each instance.
if (
self.many and
not hasattr(obj, '__iter__') and
is_simple_callable(getattr(obj, 'all', None))
):
obj = obj.all()
kwargs = {
'instance': obj,
'data': value,
'context': self.context,
'partial': self.partial,
'many': self.many,
'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
if serializer.is_valid():
into[self.source or field_name] = serializer.object
else: else:
# Propagate errors up to our parent set_value(ret, field.source_attrs, validated_value)
raise NestedValidationError(serializer.errors)
def get_identity(self, data): if errors:
""" raise ValidationError(errors)
This hook is required for bulk update.
It is used to determine the canonical identity of a given object.
Note that the data has not been validated at this point, so we need
to make sure that we catch any cases of incorrect datatypes being
passed to this method.
"""
try: try:
return data.get('id', None) return self.validate(ret)
except AttributeError: except ValidationError as exc:
return None raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: exc.messages
})
@property def to_representation(self, instance):
def errors(self):
""" """
Run deserialization and return error data, Object instance -> Dict of primitive datatypes.
setting self.object if no errors occurred.
""" """
if self._errors is None: ret = SortedDict()
data, files = self.init_data, self.init_files fields = [field for field in self.fields.values() if not field.write_only]
if self.many is not None: for field in fields:
many = self.many native_value = field.get_attribute(instance)
else: ret[field.field_name] = field.to_representation(native_value)
many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
if many:
warnings.warn('Implicit list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
DeprecationWarning, stacklevel=3)
if many:
ret = RelationsList()
errors = []
update = self.object is not None
if update:
# If this is a bulk update we need to map all the objects
# to a canonical identity so we can determine which
# individual object is being updated for each item in the
# incoming data
objects = self.object
identities = [self.get_identity(self.to_native(obj)) for obj in objects]
identity_to_objects = dict(zip(identities, objects))
if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
for item in data:
if update:
# Determine which object we're updating
identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None)
if self.object is None and not self.allow_add_remove:
ret.append(None)
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
continue
ret.append(self.from_native(item, None)) return ret
errors.append(self._errors)
if update and self.allow_add_remove: def validate(self, attrs):
ret._deleted = identity_to_objects.values() return attrs
self._errors = any(errors) and errors or [] def __iter__(self):
else: errors = self.errors if hasattr(self, '_errors') else {}
self._errors = {'non_field_errors': ['Expected a list of items.']} for field in self.fields.values():
else: value = self.data.get(field.field_name) if self.data else None
ret = self.from_native(data, files) error = errors.get(field.field_name)
yield FieldResult(field, value, error)
if not self._errors: def __repr__(self):
self.object = ret return representation.serializer_repr(self, indent=1)
return self._errors
def is_valid(self): class ListSerializer(BaseSerializer):
return not self.errors child = None
initial = []
@property def __init__(self, *args, **kwargs):
def data(self): self.child = kwargs.pop('child', copy.deepcopy(self.child))
""" assert self.child is not None, '`child` is a required argument.'
Returns the serialized data on the serializer. self.context = kwargs.pop('context', {})
""" kwargs.pop('partial', None)
if self._data is None:
obj = self.object
if self.many is not None: super(ListSerializer, self).__init__(*args, **kwargs)
many = self.many self.child.bind('', self, self)
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
warnings.warn('Implicit list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
DeprecationWarning, stacklevel=2)
if many:
self._data = [self.to_native(item) for item in obj]
else:
self._data = self.to_native(obj)
return self._data
def save_object(self, obj, **kwargs): def bind(self, field_name, parent, root):
obj.save(**kwargs) # If the list is used as a field then it needs to provide
# the current context to the child serializer.
super(ListSerializer, self).bind(field_name, parent, root)
self.child.bind(field_name, self, root)
def delete_object(self, obj): def get_value(self, dictionary):
obj.delete() # We override the default field access in order to support
# lists in HTML forms.
if is_html_input(dictionary):
return html.parse_html_list(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
def save(self, **kwargs): def to_internal_value(self, data):
""" """
Save the deserialized object and return it. List of dicts of native values <- List of dicts of primitive datatypes.
""" """
# Clear cached _data, which may be invalidated by `save()` if html.is_html_input(data):
self._data = None data = html.parse_html_list(data)
if isinstance(self.object, list): return [self.child.run_validation(item) for item in data]
[self.save_object(item, **kwargs) for item in self.object]
if self.object._deleted: def to_representation(self, data):
[self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
return self.object
def metadata(self):
""" """
Return a dictionary of metadata about the fields on the serializer. List of object instances -> List of dicts of primitive datatypes.
Useful for things like responding to OPTIONS requests, or generating
API schemas for auto-documentation.
""" """
return SortedDict( return [self.child.to_representation(item) for item in data]
[
(field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)
]
)
def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list]
class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): def save(self):
pass if self.instance is not None:
self.update(self.instance, self.validated_data)
self.instance = self.create(self.validated_data)
return self.instance
def __repr__(self):
class ModelSerializerOptions(SerializerOptions): return representation.list_repr(self, indent=1)
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None)
self.read_only_fields = getattr(meta, 'read_only_fields', ())
self.write_only_fields = getattr(meta, 'write_only_fields', ())
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
""" _field_mapping = {
A serializer that deals with model instances and querysets.
"""
_options_class = ModelSerializerOptions
field_mapping = {
models.AutoField: IntegerField, models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.DateField: DateField,
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.Field: ModelField,
models.FileField: FileField,
models.FloatField: FloatField, models.FloatField: FloatField,
models.ImageField: ImageField,
models.IntegerField: IntegerField, models.IntegerField: IntegerField,
models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField, models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField,
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
models.SlugField: SlugField, models.SlugField: SlugField,
models.SmallIntegerField: IntegerField,
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.TimeField: TimeField,
models.BooleanField: BooleanField, models.URLField: URLField,
models.NullBooleanField: BooleanField,
models.FileField: FileField,
models.ImageField: ImageField,
}
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
cls = self.opts.model
assert cls is not None, (
"Serializer class '%s' is missing 'model' Meta option" %
self.__class__.__name__
)
opts = cls._meta.concrete_model._meta
ret = SortedDict()
nested = bool(self.opts.depth)
# Deal with adding the primary key field
pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
serializer_pk_field = self.get_pk_field(pk_field)
if serializer_pk_field:
ret[pk_field.name] = serializer_pk_field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
forward_rels += [field for field in opts.many_to_many if field.serialize]
for model_field in forward_rels:
has_through_model = False
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
related_model = _resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True
if model_field.rel and nested:
if len(inspect.getargspec(self.get_nested_field).args) == 2:
warnings.warn(
'The `get_nested_field(model_field)` call signature '
'is deprecated. '
'Use `get_nested_field(model_field, related_model, '
'to_many) instead',
DeprecationWarning
)
field = self.get_nested_field(model_field)
else:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
if len(inspect.getargspec(self.get_nested_field).args) == 3:
warnings.warn(
'The `get_related_field(model_field, to_many)` call '
'signature is deprecated. '
'Use `get_related_field(model_field, related_model, '
'to_many) instead',
DeprecationWarning
)
field = self.get_related_field(model_field, to_many=to_many)
else:
field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
if has_through_model:
field.read_only = True
ret[model_field.name] = field
# Deal with reverse relationships
if not self.opts.fields:
reverse_rels = []
else:
# Reverse relationships are only included if they are explicitly
# present in the `fields` option on the serializer
reverse_rels = opts.get_all_related_objects()
reverse_rels += opts.get_all_related_many_to_many_objects()
for relation in reverse_rels:
accessor_name = relation.get_accessor_name()
if not self.opts.fields or accessor_name not in self.opts.fields:
continue
related_model = relation.model
to_many = relation.field.rel.multiple
has_through_model = False
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
if (
is_m2m and
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
):
has_through_model = True
if nested:
field = self.get_nested_field(None, related_model, to_many)
else:
field = self.get_related_field(None, related_model, to_many)
if field:
if has_through_model:
field.read_only = True
ret[accessor_name] = field
# Ensure that 'read_only_fields' is an iterable
assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
"`read_only_fields`, but also added "
"as an explicit field. Remove it from `read_only_fields`." %
(field_name, self.__class__.__name__))
assert field_name in ret, (
"Non-existant field '%s' specified in `read_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
# Ensure that 'write_only_fields' is an iterable
assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
"`write_only_fields`, but also added "
"as an explicit field. Remove it from `write_only_fields`." %
(field_name, self.__class__.__name__))
assert field_name in ret, (
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].write_only = True
return ret
def get_pk_field(self, model_field):
"""
Returns a default instance of the pk field.
"""
return self.get_field(model_field)
def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
model = related_model
depth = self.opts.depth - 1
return NestedModelSerializer(many=to_many)
def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
'queryset': related_model._default_manager,
'many': to_many
}
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
if model_field.null or model_field.blank:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
if model_field.has_default():
kwargs['default'] = model_field.get_default()
if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
if model_field.null:
kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
if issubclass(model_field.__class__, models.PositiveIntegerField) or\
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
attribute_dict = {
models.CharField: ['max_length'],
models.CommaSeparatedIntegerField: ['max_length'],
models.DecimalField: ['max_digits', 'decimal_places'],
models.EmailField: ['max_length'],
models.FileField: ['max_length'],
models.ImageField: ['max_length'],
models.SlugField: ['max_length'],
models.URLField: ['max_length'],
} }
_related_class = PrimaryKeyRelatedField
if model_field.__class__ in attribute_dict: def create(self, attrs):
attributes = attribute_dict[model_field.__class__] ModelClass = self.Meta.model
for attribute in attributes:
kwargs.update({attribute: getattr(model_field, attribute)})
try: # Remove many-to-many relationships from attrs.
return self.field_mapping[model_field.__class__](**kwargs) # They are not valid arguments to the default `.create()` method,
except KeyError: # as they require that the instance has already been saved.
return ModelField(model_field=model_field, **kwargs) info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in attrs):
many_to_many[field_name] = attrs.pop(field_name)
def get_validation_exclusions(self, instance=None): instance = ModelClass.objects.create(**attrs)
"""
Return a list of field names to exclude from model validation.
"""
cls = self.opts.model
opts = cls._meta.concrete_model._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items(): # Save many-to-many relationships after the instance is created.
field_name = field.source or field_name if many_to_many:
if ( for field_name, value in many_to_many.items():
field_name in exclusions setattr(instance, field_name, value)
and not field.read_only
and (field.required or hasattr(instance, field_name))
and not isinstance(field, Serializer)
):
exclusions.remove(field_name)
return exclusions
def full_clean(self, instance):
"""
Perform Django's full_clean, and populate the `errors` dictionary
if any validation errors occur.
Note that we don't perform this inside the `.restore_object()` method,
so that subclasses can override `.restore_object()`, and still get
the full_clean validation checking.
"""
try:
instance.full_clean(exclude=self.get_validation_exclusions(instance))
except ValidationError as err:
self._errors = err.message_dict
return None
return instance return instance
def restore_object(self, attrs, instance=None): def update(self, obj, attrs):
""" for attr, value in attrs.items():
Restore the model instance. setattr(obj, attr, value)
""" obj.save()
m2m_data = {}
related_data = {}
nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.get_accessor_name()
if field_name in attrs:
related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
for field in meta.many_to_many + meta.virtual_fields:
if isinstance(field, GenericForeignKey):
continue
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
# Nested forward relations - These need to be marked so we can save
# them before saving the parent model instance.
for field_name in attrs.keys():
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
# Create an empty instance of the model
if instance is None:
instance = self.opts.model()
for key, val in attrs.items(): def _get_base_fields(self):
try: declared_fields = copy.deepcopy(self._declared_fields)
setattr(instance, key, val)
except ValueError:
self._errors[key] = [self.error_messages['required']]
# Any relations that cannot be set until we've
# saved the model get hidden away on these
# private attributes, so we can deal with them
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations
return instance
def from_native(self, data, files): ret = SortedDict()
""" model = getattr(self.Meta, 'model')
Override the default method to also include model field validation. fields = getattr(self.Meta, 'fields', None)
""" depth = getattr(self.Meta, 'depth', 0)
instance = super(ModelSerializer, self).from_native(data, files) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
if not self._errors:
return self.full_clean(instance) # Retrieve metadata about fields & relationships on the model class.
info = model_meta.get_field_info(model)
# Use the default set of fields if none is supplied explicitly.
if fields is None:
fields = self._get_default_field_names(declared_fields, info)
for field_name in fields:
if field_name in declared_fields:
# Field is explicitly declared on the class, use that.
ret[field_name] = declared_fields[field_name]
continue
def save_object(self, obj, **kwargs): elif field_name == api_settings.URL_FIELD_NAME:
""" # Create the URL field.
Save the deserialized object. field_cls = HyperlinkedIdentityField
""" kwargs = get_url_kwargs(model)
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the elif field_name in info.fields_and_pk:
# parent instance. # Create regular model fields.
for field_name, sub_object in obj._nested_forward_relations.items(): model_field = info.fields_and_pk[field_name]
if sub_object: field_cls = lookup_class(self._field_mapping, model_field)
self.save_object(sub_object) kwargs = get_field_kwargs(field_name, model_field)
setattr(obj, field_name, sub_object) if 'choices' in kwargs:
# Fields with choices get coerced into `ChoiceField`
obj.save(**kwargs) # instead of using their regular typed field.
field_cls = ChoiceField
if getattr(obj, '_m2m_data', None): if not issubclass(field_cls, ModelField):
for accessor_name, object_list in obj._m2m_data.items(): # `model_field` is only valid for the fallback case of
setattr(obj, accessor_name, object_list) # `ModelField`, which is used when no other typed field
del(obj._m2m_data) # matched to the model field.
kwargs.pop('model_field', None)
if getattr(obj, '_related_data', None):
related_fields = dict([ elif field_name in info.relations:
(field.get_accessor_name(), field) # Create forward and reverse relationships.
for field, model relation_info = info.relations[field_name]
in obj._meta.get_all_related_objects_with_model() if depth:
]) field_cls = self._get_nested_class(depth, relation_info)
for accessor_name, related in obj._related_data.items(): kwargs = get_nested_relation_kwargs(relation_info)
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
fk_field = related_fields[accessor_name].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)
# Delete any removed objects
if related._deleted:
[self.delete_object(item) for item in related._deleted]
elif isinstance(related, models.Model):
# Nested reverse one-one relationship
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related, fk_field, obj)
self.save_object(related)
else: else:
# Reverse FK or reverse one-one field_cls = self._related_class
setattr(obj, accessor_name, related) kwargs = get_relation_kwargs(field_name, relation_info)
del(obj._related_data) # `view_name` is only valid for hyperlinked relationships.
if not issubclass(field_cls, HyperlinkedRelatedField):
kwargs.pop('view_name', None)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
""" elif hasattr(model, field_name):
Options for HyperlinkedModelSerializer # Create a read only field for model methods and properties.
""" field_cls = ReadOnlyField
def __init__(self, meta): kwargs = {}
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer):
"""
A subclass of ModelSerializer that uses hyperlinked relationships,
instead of primary key relationships.
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None: else:
self.opts.view_name = self._get_default_view_name(self.opts.model) raise ImproperlyConfigured(
'Field name `%s` is not valid for model `%s`.' %
(field_name, model.__class__.__name__)
)
if self.opts.url_field_name not in fields: # Check that any fields declared on the class are
url_field = self._hyperlink_identify_field_class( # also explicity included in `Meta.fields`.
view_name=self.opts.view_name, missing_fields = set(declared_fields.keys()) - set(fields)
lookup_field=self.opts.lookup_field if missing_fields:
missing_field = list(missing_fields)[0]
raise ImproperlyConfigured(
'Field `%s` has been declared on serializer `%s`, but '
'is missing from `Meta.fields`.' %
(missing_field, self.__class__.__name__)
) )
ret = self._dict_class()
ret[self.opts.url_field_name] = url_field
ret.update(fields)
fields = ret
return fields # Populate any kwargs defined in `Meta.extra_kwargs`
kwargs.update(extra_kwargs.get(field_name, {}))
def get_pk_field(self, model_field): # Create the serializer field.
if self.opts.fields and model_field.name in self.opts.fields: ret[field_name] = field_cls(**kwargs)
return self.get_field(model_field)
def get_related_field(self, model_field, related_model, to_many): return ret
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
'queryset': related_model._default_manager,
'view_name': self._get_default_view_name(related_model),
'many': to_many
}
if model_field: def _get_default_field_names(self, declared_fields, model_info):
kwargs['required'] = not(model_field.null or model_field.blank) return (
if model_field.help_text is not None: [model_info.pk.name] +
kwargs['help_text'] = model_field.help_text list(declared_fields.keys()) +
if model_field.verbose_name is not None: list(model_info.fields.keys()) +
kwargs['label'] = model_field.verbose_name list(model_info.forward_relations.keys())
)
if self.opts.lookup_field: def _get_nested_class(self, nested_depth, relation_info):
kwargs['lookup_field'] = self.opts.lookup_field class NestedSerializer(ModelSerializer):
class Meta:
model = relation_info.related
depth = nested_depth
return NestedSerializer
return self._hyperlink_field_class(**kwargs)
def get_identity(self, data): class HyperlinkedModelSerializer(ModelSerializer):
""" _related_class = HyperlinkedRelatedField
This hook is required for bulk update.
We need to override the default, to use the url as the identity. def _get_default_field_names(self, declared_fields, model_info):
""" return (
try: [api_settings.URL_FIELD_NAME] +
return data.get(self.opts.url_field_name, None) list(declared_fields.keys()) +
except AttributeError: list(model_info.fields.keys()) +
return None list(model_info.forward_relations.keys())
)
def _get_default_view_name(self, model): def _get_nested_class(self, nested_depth, relation_info):
""" class NestedSerializer(HyperlinkedModelSerializer):
Return the view name to use if 'view_name' is not specified in 'Meta' class Meta:
""" model = relation_info.related
model_meta = model._meta depth = nested_depth
format_kwargs = { return NestedSerializer
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
}
return self._default_view_name % format_kwargs
...@@ -77,6 +77,7 @@ DEFAULTS = { ...@@ -77,6 +77,7 @@ DEFAULTS = {
# Exception handling # Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
'NON_FIELD_ERRORS_KEY': 'non_field_errors',
# Testing # Testing
'TEST_REQUEST_RENDERER_CLASSES': ( 'TEST_REQUEST_RENDERER_CLASSES': (
...@@ -96,24 +97,19 @@ DEFAULTS = { ...@@ -96,24 +97,19 @@ DEFAULTS = {
'URL_FIELD_NAME': 'url', 'URL_FIELD_NAME': 'url',
# Input and output formats # Input and output formats
'DATE_INPUT_FORMATS': ( 'DATE_FORMAT': ISO_8601,
ISO_8601, 'DATE_INPUT_FORMATS': (ISO_8601,),
),
'DATE_FORMAT': None,
'DATETIME_INPUT_FORMATS': ( 'DATETIME_FORMAT': ISO_8601,
ISO_8601, 'DATETIME_INPUT_FORMATS': (ISO_8601,),
),
'DATETIME_FORMAT': None,
'TIME_INPUT_FORMATS': ( 'TIME_FORMAT': ISO_8601,
ISO_8601, 'TIME_INPUT_FORMATS': (ISO_8601,),
),
'TIME_FORMAT': None,
# Pending deprecation
'FILTER_BACKEND': None,
# Encoding
'UNICODE_JSON': True,
'COMPACT_JSON': True,
'COERCE_DECIMAL_TO_STRING': True
} }
...@@ -129,7 +125,6 @@ IMPORT_STRINGS = ( ...@@ -129,7 +125,6 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS', 'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER', 'EXCEPTION_HANDLER',
'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES', 'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN', 'UNAUTHENTICATED_TOKEN',
...@@ -196,15 +191,9 @@ class APISettings(object): ...@@ -196,15 +191,9 @@ class APISettings(object):
if val and attr in self.import_strings: if val and attr in self.import_strings:
val = perform_import(val, attr) val = perform_import(val, attr)
self.validate_setting(attr, val)
# Cache the result # Cache the result
setattr(self, attr, val) setattr(self, attr, val)
return val return val
def validate_setting(self, attr, val):
if attr == 'FILTER_BACKEND' and val is not None:
# Make sure we can initialize the class
val()
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
...@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory): ...@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):
Encode the data returning a two tuple of (bytes, content_type) Encode the data returning a two tuple of (bytes, content_type)
""" """
if not data: if data is None:
return ('', content_type) return ('', content_type)
assert format is None or content_type is None, ( assert format is None or content_type is None, (
......
...@@ -7,7 +7,6 @@ from django.db.models.query import QuerySet ...@@ -7,7 +7,6 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.functional import Promise from django.utils.functional import Promise
from rest_framework.compat import force_text from rest_framework.compat import force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime import datetime
import decimal import decimal
import types import types
...@@ -17,45 +16,47 @@ import json ...@@ -17,45 +16,47 @@ import json
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
""" """
JSONEncoder subclass that knows how to encode date/time/timedelta, JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, and generators. decimal types, generators and other basic python objects.
""" """
def default(self, o): def default(self, obj):
# For Date Time string spec, see ECMA 262 # For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
if isinstance(o, Promise): if isinstance(obj, Promise):
return force_text(o) return force_text(obj)
elif isinstance(o, datetime.datetime): elif isinstance(obj, datetime.datetime):
r = o.isoformat() representation = obj.isoformat()
if o.microsecond: if obj.microsecond:
r = r[:23] + r[26:] representation = representation[:23] + representation[26:]
if r.endswith('+00:00'): if representation.endswith('+00:00'):
r = r[:-6] + 'Z' representation = representation[:-6] + 'Z'
return r return representation
elif isinstance(o, datetime.date): elif isinstance(obj, datetime.date):
return o.isoformat() return obj.isoformat()
elif isinstance(o, datetime.time): elif isinstance(obj, datetime.time):
if timezone and timezone.is_aware(o): if timezone and timezone.is_aware(obj):
raise ValueError("JSON can't represent timezone-aware times.") raise ValueError("JSON can't represent timezone-aware times.")
r = o.isoformat() representation = obj.isoformat()
if o.microsecond: if obj.microsecond:
r = r[:12] representation = representation[:12]
return r return representation
elif isinstance(o, datetime.timedelta): elif isinstance(obj, datetime.timedelta):
return str(o.total_seconds()) return str(obj.total_seconds())
elif isinstance(o, decimal.Decimal): elif isinstance(obj, decimal.Decimal):
return str(o) # Serializers will coerce decimals to strings by default.
elif isinstance(o, QuerySet): return float(obj)
return list(o) elif isinstance(obj, QuerySet):
elif hasattr(o, 'tolist'): return list(obj)
return o.tolist() elif hasattr(obj, 'tolist'):
elif hasattr(o, '__getitem__'): # Numpy arrays and array scalars.
return obj.tolist()
elif hasattr(obj, '__getitem__'):
try: try:
return dict(o) return dict(obj)
except: except:
pass pass
elif hasattr(o, '__iter__'): elif hasattr(obj, '__iter__'):
return [i for i in o] return [item for item in obj]
return super(JSONEncoder, self).default(o) return super(JSONEncoder, self).default(obj)
try: try:
...@@ -106,14 +107,14 @@ else: ...@@ -106,14 +107,14 @@ else:
SortedDict, SortedDict,
yaml.representer.SafeRepresenter.represent_dict yaml.representer.SafeRepresenter.represent_dict
) )
SafeDumper.add_representer( # SafeDumper.add_representer(
DictWithMetadata, # DictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict # yaml.representer.SafeRepresenter.represent_dict
) # )
SafeDumper.add_representer( # SafeDumper.add_representer(
SortedDictWithMetadata, # SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict # yaml.representer.SafeRepresenter.represent_dict
) # )
SafeDumper.add_representer( SafeDumper.add_representer(
types.GeneratorType, types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list yaml.representer.SafeRepresenter.represent_list
......
"""
Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivelent serializer fields.
"""
from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
import inspect
def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
for cls in inspect.getmro(instance.__class__):
if cls in mapping:
return mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
return capfirst(model_field.verbose_name) != default_label
def get_detail_view_name(model):
"""
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.flatchoices:
# If this model field contains choices, then return now,
# any further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
# Also note that Django<1.6 uses `default=False` for
# models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
return kwargs
def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, has_through_model = relation_info
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
if relation_info.to_many:
kwargs['many'] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
Utility functions to return a formatted name and description for a given view. Utility functions to return a formatted name and description for a given view.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import re
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown
import re from rest_framework.compat import apply_markdown, force_text
def remove_trailing_string(content, trailing): def remove_trailing_string(content, trailing):
...@@ -28,6 +29,7 @@ def dedent(content): ...@@ -28,6 +29,7 @@ def dedent(content):
as it fails to dedent multiline docstrings that include as it fails to dedent multiline docstrings that include
unindented text on the initial line. unindented text on the initial line.
""" """
content = force_text(content)
whitespace_counts = [len(line) - len(line.lstrip(' ')) whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()] for line in content.splitlines()[1:] if line.lstrip()]
......
"""
Helpers for dealing with HTML input.
"""
import re
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key.
return hasattr(dictionary, 'getlist')
def parse_html_list(dictionary, prefix=''):
"""
Used to suport list values in HTML forms.
Supports lists of primitives and/or dictionaries.
* List of primitives.
{
'[0]': 'abc',
'[1]': 'def',
'[2]': 'hij'
}
-->
[
'abc',
'def',
'hij'
]
* List of dictionaries.
{
'[0]foo': 'abc',
'[0]bar': 'def',
'[1]foo': 'hij',
'[2]bar': 'klm',
}
-->
[
{'foo': 'abc', 'bar': 'def'},
{'foo': 'hij', 'bar': 'klm'}
]
"""
Dict = type(dictionary)
ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
continue
index, key = match.groups()
index = int(index)
if not key:
ret[index] = value
elif isinstance(ret.get(index), dict):
ret[index][key] = value
else:
ret[index] = Dict({key: value})
return [ret[item] for item in sorted(ret.keys())]
def parse_html_dict(dictionary, prefix):
"""
Used to support dictionary values in HTML forms.
{
'profile.username': 'example',
'profile.email': 'example@example.com',
}
-->
{
'profile': {
'username': 'example,
'email': 'example@example.com'
}
}
"""
ret = {}
regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
continue
key = match.groups()[0]
ret[key] = value
return ret
"""
Helper functions that convert strftime formats into more readable representations.
"""
from rest_framework import ISO_8601
def datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format)
def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
return humanize_strptime(format)
def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string
"""
Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse
relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
import inspect
FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related',
'to_many',
'has_through_model'
])
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
def get_field_info(model):
"""
Given a model class, returns a `FieldInfo` instance containing metadata
about the various field types on the model.
"""
opts = model._meta.concrete_model._meta
# Deal with the primary key.
pk = opts.pk
while pk.rel and pk.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk.
pk = pk.rel.to._meta.pk
# Deal with regular fields.
fields = SortedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
# Deal with forward relationships.
forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
)
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
not field.rel.through._meta.auto_created
)
)
# Deal with reverse relationships.
reverse_relations = SortedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
)
# Deal with reverse many-to-many relationships.
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related=relation.model,
to_many=True,
has_through_model=(
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
)
)
# Shortcut that merges both regular fields and the pk,
# for simplifying regular field lookup.
fields_and_pk = SortedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
# Shortcut that merges both forward and reverse relationships
relations = SortedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
"""
Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
from django.db import models
import re
def manager_repr(value):
model = value.model
opts = model._meta
for _, name, manager in opts.concrete_managers + opts.abstract_managers:
if manager == value:
return '%s.%s.all()' % (model._meta.object_name, name)
return repr(value)
def smart_repr(value):
if isinstance(value, models.Manager):
return manager_repr(value)
value = repr(value)
# Representations like u'help text'
# should simply be presented as 'help text'
if value.startswith("u'") and value.endswith("'"):
return value[1:]
# Representations like
# <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as
# <django.core.validators.RegexValidator object>
value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value)
return value
def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
kwargs['many'] = True
kwargs.pop('child', None)
arg_string = ', '.join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([
'%s=%s' % (key, smart_repr(val))
for key, val in sorted(kwargs.items())
])
if arg_string and kwarg_string:
arg_string += ', '
if force_many:
class_name = force_many.__class__.__name__
else:
class_name = field.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':'
indent_str = ' ' * indent
if force_many:
fields = force_many.fields
else:
fields = serializer.fields
for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = '
if hasattr(field, 'fields'):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
elif hasattr(field, 'child_relation'):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
return ret
def list_repr(serializer, indent):
child = serializer.child
if hasattr(child, 'fields'):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)
...@@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. ...@@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS
from django.http import Http404 from django.http import Http404
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
...@@ -51,7 +51,8 @@ def exception_handler(exc): ...@@ -51,7 +51,8 @@ def exception_handler(exc):
Returns the response that should be used for any given exception. Returns the response that should be used for any given exception.
By default we handle the REST framework `APIException`, and also By default we handle the REST framework `APIException`, and also
Django's builtin `Http404` and `PermissionDenied` exceptions. Django's built-in `ValidationError`, `Http404` and `PermissionDenied`
exceptions.
Any unhandled exceptions may return `None`, which will cause a 500 error Any unhandled exceptions may return `None`, which will cause a 500 error
to be raised. to be raised.
...@@ -61,13 +62,22 @@ def exception_handler(exc): ...@@ -61,13 +62,22 @@ def exception_handler(exc):
if getattr(exc, 'auth_header', None): if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None): if getattr(exc, 'wait', None):
headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
headers['Retry-After'] = '%d' % exc.wait headers['Retry-After'] = '%d' % exc.wait
return Response({'detail': exc.detail}, return Response({'detail': exc.detail},
status=exc.status_code, status=exc.status_code,
headers=headers) headers=headers)
elif isinstance(exc, ValidationError):
# ValidationErrors may include the non-field key named '__all__'.
# When returning a response we map this to a key name that can be
# modified in settings.
if NON_FIELD_ERRORS in exc.message_dict:
errors = exc.message_dict.pop(NON_FIELD_ERRORS)
exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors
return Response(exc.message_dict,
status=status.HTTP_400_BAD_REQUEST)
elif isinstance(exc, Http404): elif isinstance(exc, Http404):
return Response({'detail': 'Not found'}, return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND) status=status.HTTP_404_NOT_FOUND)
......
...@@ -20,6 +20,7 @@ from __future__ import unicode_literals ...@@ -20,6 +20,7 @@ from __future__ import unicode_literals
from functools import update_wrapper from functools import update_wrapper
from django.utils.decorators import classonlymethod from django.utils.decorators import classonlymethod
from django.views.decorators.csrf import csrf_exempt
from rest_framework import views, generics, mixins from rest_framework import views, generics, mixins
...@@ -89,7 +90,7 @@ class ViewSetMixin(object): ...@@ -89,7 +90,7 @@ class ViewSetMixin(object):
# resolved URL. # resolved URL.
view.cls = cls view.cls = cls
view.suffix = initkwargs.get('suffix', None) view.suffix = initkwargs.get('suffix', None)
return view return csrf_exempt(view)
def initialize_request(self, request, *args, **kargs): def initialize_request(self, request, *args, **kargs):
""" """
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
def foobar(): def foobar():
...@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel): ...@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True, target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source') related_name='nullable_source')
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
# From test_validation...
class TestPreSaveValidationExclusions(TestCase):
def test_pre_save_validation_exclusions(self):
"""
Somewhat weird test case to ensure that we don't perform model
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', {}, format='json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
# From test_permissions...
class ModelPermissionsIntegrationTests(TestCase):
def setUp(...):
...
def test_has_put_as_create_permissions(self):
# User only has update permissions - should be able to update an entity.
request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied.
request = factory.put('/2', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
from rest_framework import serializers
from tests.models import NullableForeignKeySource
class NullableFKSourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
...@@ -98,6 +98,30 @@ class TestViewNamesAndDescriptions(TestCase): ...@@ -98,6 +98,30 @@ class TestViewNamesAndDescriptions(TestCase):
pass pass
self.assertEqual(MockView().get_view_description(), '') self.assertEqual(MockView().get_view_description(), '')
def test_view_description_can_be_promise(self):
"""
Ensure a view may have a docstring that is actually a lazily evaluated
class that can be converted to a string.
See: https://github.com/tomchristie/django-rest-framework/issues/1708
"""
# use a mock object instead of gettext_lazy to ensure that we can't end
# up with a test case string in our l10n catalog
class MockLazyStr(object):
def __init__(self, string):
self.s = string
def __str__(self):
return self.s
def __unicode__(self):
return self.s
class MockView(APIView):
__doc__ = MockLazyStr("a gettext string")
self.assertEqual(MockView().get_view_description(), 'a gettext string')
def test_markdown(self): def test_markdown(self):
""" """
Ensure markdown to HTML works as expected. Ensure markdown to HTML works as expected.
......
""" # """
General serializer field tests. # General serializer field tests.
""" # """
from __future__ import unicode_literals # from __future__ import unicode_literals
import datetime # import datetime
import re # import re
from decimal import Decimal # from decimal import Decimal
from uuid import uuid4 # from uuid import uuid4
from django.core import validators # from django.core import validators
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from django.utils.datastructures import SortedDict # from django.utils.datastructures import SortedDict
from rest_framework import serializers # from rest_framework import serializers
from tests.models import RESTFrameworkModel # from tests.models import RESTFrameworkModel
class TimestampedModel(models.Model): # class TimestampedModel(models.Model):
added = models.DateTimeField(auto_now_add=True) # added = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True) # updated = models.DateTimeField(auto_now=True)
class CharPrimaryKeyModel(models.Model): # class CharPrimaryKeyModel(models.Model):
id = models.CharField(max_length=20, primary_key=True) # id = models.CharField(max_length=20, primary_key=True)
class TimestampedModelSerializer(serializers.ModelSerializer): # class TimestampedModelSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = TimestampedModel # model = TimestampedModel
class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): # class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = CharPrimaryKeyModel # model = CharPrimaryKeyModel
class TimeFieldModel(models.Model): # class TimeFieldModel(models.Model):
clock = models.TimeField() # clock = models.TimeField()
class TimeFieldModelSerializer(serializers.ModelSerializer): # class TimeFieldModelSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = TimeFieldModel # model = TimeFieldModel
SAMPLE_CHOICES = [ # SAMPLE_CHOICES = [
('red', 'Red'), # ('red', 'Red'),
('green', 'Green'), # ('green', 'Green'),
('blue', 'Blue'), # ('blue', 'Blue'),
] # ]
class ChoiceFieldModel(models.Model): # class ChoiceFieldModel(models.Model):
choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255) # choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
class ChoiceFieldModelSerializer(serializers.ModelSerializer): # class ChoiceFieldModelSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ChoiceFieldModel # model = ChoiceFieldModel
class ChoiceFieldModelWithNull(models.Model): # class ChoiceFieldModelWithNull(models.Model):
choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255) # choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer): # class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ChoiceFieldModelWithNull # model = ChoiceFieldModelWithNull
class BasicFieldTests(TestCase): # class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self): # def test_auto_now_fields_read_only(self):
""" # """
auto_now and auto_now_add fields should be read_only by default. # auto_now and auto_now_add fields should be read_only by default.
""" # """
serializer = TimestampedModelSerializer() # serializer = TimestampedModelSerializer()
self.assertEqual(serializer.fields['added'].read_only, True) # self.assertEqual(serializer.fields['added'].read_only, True)
def test_auto_pk_fields_read_only(self): # def test_auto_pk_fields_read_only(self):
""" # """
AutoField fields should be read_only by default. # AutoField fields should be read_only by default.
""" # """
serializer = TimestampedModelSerializer() # serializer = TimestampedModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, True) # self.assertEqual(serializer.fields['id'].read_only, True)
def test_non_auto_pk_fields_not_read_only(self): # def test_non_auto_pk_fields_not_read_only(self):
""" # """
PK fields other than AutoField fields should not be read_only by default. # PK fields other than AutoField fields should not be read_only by default.
""" # """
serializer = CharPrimaryKeyModelSerializer() # serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False) # self.assertEqual(serializer.fields['id'].read_only, False)
def test_dict_field_ordering(self): # def test_dict_field_ordering(self):
""" # """
Field should preserve dictionary ordering, if it exists. # Field should preserve dictionary ordering, if it exists.
See: https://github.com/tomchristie/django-rest-framework/issues/832 # See: https://github.com/tomchristie/django-rest-framework/issues/832
""" # """
ret = SortedDict() # ret = SortedDict()
ret['c'] = 1 # ret['c'] = 1
ret['b'] = 1 # ret['b'] = 1
ret['a'] = 1 # ret['a'] = 1
ret['z'] = 1 # ret['z'] = 1
field = serializers.Field() # field = serializers.Field()
keys = list(field.to_native(ret).keys()) # keys = list(field.to_native(ret).keys())
self.assertEqual(keys, ['c', 'b', 'a', 'z']) # self.assertEqual(keys, ['c', 'b', 'a', 'z'])
def test_widget_html_attributes(self): # def test_widget_html_attributes(self):
""" # """
Make sure widget_html() renders the correct attributes # Make sure widget_html() renders the correct attributes
""" # """
r = re.compile('(\S+)=["\']?((?:.(?!["\']?\s+(?:\S+)=|[>"\']))+.)["\']?') # r = re.compile('(\S+)=["\']?((?:.(?!["\']?\s+(?:\S+)=|[>"\']))+.)["\']?')
form = TimeFieldModelSerializer().data # form = TimeFieldModelSerializer().data
attributes = r.findall(form.fields['clock'].widget_html()) # attributes = r.findall(form.fields['clock'].widget_html())
self.assertIn(('name', 'clock'), attributes) # self.assertIn(('name', 'clock'), attributes)
self.assertIn(('id', 'clock'), attributes) # self.assertIn(('id', 'clock'), attributes)
class DateFieldTest(TestCase): # class DateFieldTest(TestCase):
""" # """
Tests for the DateFieldTest from_native() and to_native() behavior # Tests for the DateFieldTest from_native() and to_native() behavior
""" # """
def test_from_native_string(self): # def test_from_native_string(self):
""" # """
Make sure from_native() accepts default iso input formats. # Make sure from_native() accepts default iso input formats.
""" # """
f = serializers.DateField() # f = serializers.DateField()
result_1 = f.from_native('1984-07-31') # result_1 = f.from_native('1984-07-31')
self.assertEqual(datetime.date(1984, 7, 31), result_1) # self.assertEqual(datetime.date(1984, 7, 31), result_1)
def test_from_native_datetime_date(self): # def test_from_native_datetime_date(self):
""" # """
Make sure from_native() accepts a datetime.date instance. # Make sure from_native() accepts a datetime.date instance.
""" # """
f = serializers.DateField() # f = serializers.DateField()
result_1 = f.from_native(datetime.date(1984, 7, 31)) # result_1 = f.from_native(datetime.date(1984, 7, 31))
self.assertEqual(result_1, datetime.date(1984, 7, 31)) # self.assertEqual(result_1, datetime.date(1984, 7, 31))
def test_from_native_custom_format(self): # def test_from_native_custom_format(self):
""" # """
Make sure from_native() accepts custom input formats. # Make sure from_native() accepts custom input formats.
""" # """
f = serializers.DateField(input_formats=['%Y -- %d']) # f = serializers.DateField(input_formats=['%Y -- %d'])
result = f.from_native('1984 -- 31') # result = f.from_native('1984 -- 31')
self.assertEqual(datetime.date(1984, 1, 31), result) # self.assertEqual(datetime.date(1984, 1, 31), result)
def test_from_native_invalid_default_on_custom_format(self): # def test_from_native_invalid_default_on_custom_format(self):
""" # """
Make sure from_native() don't accept default formats if custom format is preset # Make sure from_native() don't accept default formats if custom format is preset
""" # """
f = serializers.DateField(input_formats=['%Y -- %d']) # f = serializers.DateField(input_formats=['%Y -- %d'])
try: # try:
f.from_native('1984-07-31') # f.from_native('1984-07-31')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) # self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_empty(self): # def test_from_native_empty(self):
""" # """
Make sure from_native() returns None on empty param. # Make sure from_native() returns None on empty param.
""" # """
f = serializers.DateField() # f = serializers.DateField()
result = f.from_native('') # result = f.from_native('')
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_none(self): # def test_from_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DateField() # f = serializers.DateField()
result = f.from_native(None) # result = f.from_native(None)
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_invalid_date(self): # def test_from_native_invalid_date(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid date. # Make sure from_native() raises a ValidationError on passing an invalid date.
""" # """
f = serializers.DateField() # f = serializers.DateField()
try: # try:
f.from_native('1984-13-31') # f.from_native('1984-13-31')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) # self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_invalid_format(self): # def test_from_native_invalid_format(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid format. # Make sure from_native() raises a ValidationError on passing an invalid format.
""" # """
f = serializers.DateField() # f = serializers.DateField()
try: # try:
f.from_native('1984 -- 31') # f.from_native('1984 -- 31')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) # self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_to_native(self): # def test_to_native(self):
""" # """
Make sure to_native() returns datetime as default. # Make sure to_native() returns datetime as default.
""" # """
f = serializers.DateField() # f = serializers.DateField()
result_1 = f.to_native(datetime.date(1984, 7, 31)) # result_1 = f.to_native(datetime.date(1984, 7, 31))
self.assertEqual(datetime.date(1984, 7, 31), result_1) # self.assertEqual(datetime.date(1984, 7, 31), result_1)
def test_to_native_iso(self): # def test_to_native_iso(self):
""" # """
Make sure to_native() with 'iso-8601' returns iso formated date. # Make sure to_native() with 'iso-8601' returns iso formated date.
""" # """
f = serializers.DateField(format='iso-8601') # f = serializers.DateField(format='iso-8601')
result_1 = f.to_native(datetime.date(1984, 7, 31)) # result_1 = f.to_native(datetime.date(1984, 7, 31))
self.assertEqual('1984-07-31', result_1) # self.assertEqual('1984-07-31', result_1)
def test_to_native_custom_format(self): # def test_to_native_custom_format(self):
""" # """
Make sure to_native() returns correct custom format. # Make sure to_native() returns correct custom format.
""" # """
f = serializers.DateField(format="%Y - %m.%d") # f = serializers.DateField(format="%Y - %m.%d")
result_1 = f.to_native(datetime.date(1984, 7, 31)) # result_1 = f.to_native(datetime.date(1984, 7, 31))
self.assertEqual('1984 - 07.31', result_1) # self.assertEqual('1984 - 07.31', result_1)
def test_to_native_none(self): # def test_to_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DateField(required=False) # f = serializers.DateField(required=False)
self.assertEqual(None, f.to_native(None)) # self.assertEqual(None, f.to_native(None))
class DateTimeFieldTest(TestCase): # class DateTimeFieldTest(TestCase):
""" # """
Tests for the DateTimeField from_native() and to_native() behavior # Tests for the DateTimeField from_native() and to_native() behavior
""" # """
def test_from_native_string(self): # def test_from_native_string(self):
""" # """
Make sure from_native() accepts default iso input formats. # Make sure from_native() accepts default iso input formats.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
result_1 = f.from_native('1984-07-31 04:31') # result_1 = f.from_native('1984-07-31 04:31')
result_2 = f.from_native('1984-07-31 04:31:59') # result_2 = f.from_native('1984-07-31 04:31:59')
result_3 = f.from_native('1984-07-31 04:31:59.000200') # result_3 = f.from_native('1984-07-31 04:31:59.000200')
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
def test_from_native_datetime_datetime(self): # def test_from_native_datetime_datetime(self):
""" # """
Make sure from_native() accepts a datetime.datetime instance. # Make sure from_native() accepts a datetime.datetime instance.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) # result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) # result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) # result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) # self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) # self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) # self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
def test_from_native_custom_format(self): # def test_from_native_custom_format(self):
""" # """
Make sure from_native() accepts custom input formats. # Make sure from_native() accepts custom input formats.
""" # """
f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) # f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
result = f.from_native('1984 -- 04:59') # result = f.from_native('1984 -- 04:59')
self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) # self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
def test_from_native_invalid_default_on_custom_format(self): # def test_from_native_invalid_default_on_custom_format(self):
""" # """
Make sure from_native() don't accept default formats if custom format is preset # Make sure from_native() don't accept default formats if custom format is preset
""" # """
f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) # f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
try: # try:
f.from_native('1984-07-31 04:31:59') # f.from_native('1984-07-31 04:31:59')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) # self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_empty(self): # def test_from_native_empty(self):
""" # """
Make sure from_native() returns None on empty param. # Make sure from_native() returns None on empty param.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
result = f.from_native('') # result = f.from_native('')
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_none(self): # def test_from_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
result = f.from_native(None) # result = f.from_native(None)
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_invalid_datetime(self): # def test_from_native_invalid_datetime(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid datetime. # Make sure from_native() raises a ValidationError on passing an invalid datetime.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
try: # try:
f.from_native('04:61:59') # f.from_native('04:61:59')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " # self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
"YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"]) # "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_invalid_format(self): # def test_from_native_invalid_format(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid format. # Make sure from_native() raises a ValidationError on passing an invalid format.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
try: # try:
f.from_native('04 -- 31') # f.from_native('04 -- 31')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " # self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
"YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"]) # "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_to_native(self): # def test_to_native(self):
""" # """
Make sure to_native() returns isoformat as default. # Make sure to_native() returns isoformat as default.
""" # """
f = serializers.DateTimeField() # f = serializers.DateTimeField()
result_1 = f.to_native(datetime.datetime(1984, 7, 31)) # result_1 = f.to_native(datetime.datetime(1984, 7, 31))
result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) # result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) # result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) # result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
self.assertEqual(datetime.datetime(1984, 7, 31), result_1) # self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) # self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
def test_to_native_iso(self): # def test_to_native_iso(self):
""" # """
Make sure to_native() with format=iso-8601 returns iso formatted datetime. # Make sure to_native() with format=iso-8601 returns iso formatted datetime.
""" # """
f = serializers.DateTimeField(format='iso-8601') # f = serializers.DateTimeField(format='iso-8601')
result_1 = f.to_native(datetime.datetime(1984, 7, 31)) # result_1 = f.to_native(datetime.datetime(1984, 7, 31))
result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) # result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) # result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) # result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
self.assertEqual('1984-07-31T00:00:00', result_1) # self.assertEqual('1984-07-31T00:00:00', result_1)
self.assertEqual('1984-07-31T04:31:00', result_2) # self.assertEqual('1984-07-31T04:31:00', result_2)
self.assertEqual('1984-07-31T04:31:59', result_3) # self.assertEqual('1984-07-31T04:31:59', result_3)
self.assertEqual('1984-07-31T04:31:59.000200', result_4) # self.assertEqual('1984-07-31T04:31:59.000200', result_4)
def test_to_native_custom_format(self): # def test_to_native_custom_format(self):
""" # """
Make sure to_native() returns correct custom format. # Make sure to_native() returns correct custom format.
""" # """
f = serializers.DateTimeField(format="%Y - %H:%M") # f = serializers.DateTimeField(format="%Y - %H:%M")
result_1 = f.to_native(datetime.datetime(1984, 7, 31)) # result_1 = f.to_native(datetime.datetime(1984, 7, 31))
result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) # result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) # result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) # result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
self.assertEqual('1984 - 00:00', result_1) # self.assertEqual('1984 - 00:00', result_1)
self.assertEqual('1984 - 04:31', result_2) # self.assertEqual('1984 - 04:31', result_2)
self.assertEqual('1984 - 04:31', result_3) # self.assertEqual('1984 - 04:31', result_3)
self.assertEqual('1984 - 04:31', result_4) # self.assertEqual('1984 - 04:31', result_4)
def test_to_native_none(self): # def test_to_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DateTimeField(required=False) # f = serializers.DateTimeField(required=False)
self.assertEqual(None, f.to_native(None)) # self.assertEqual(None, f.to_native(None))
class TimeFieldTest(TestCase): # class TimeFieldTest(TestCase):
""" # """
Tests for the TimeField from_native() and to_native() behavior # Tests for the TimeField from_native() and to_native() behavior
""" # """
def test_from_native_string(self): # def test_from_native_string(self):
""" # """
Make sure from_native() accepts default iso input formats. # Make sure from_native() accepts default iso input formats.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
result_1 = f.from_native('04:31') # result_1 = f.from_native('04:31')
result_2 = f.from_native('04:31:59') # result_2 = f.from_native('04:31:59')
result_3 = f.from_native('04:31:59.000200') # result_3 = f.from_native('04:31:59.000200')
self.assertEqual(datetime.time(4, 31), result_1) # self.assertEqual(datetime.time(4, 31), result_1)
self.assertEqual(datetime.time(4, 31, 59), result_2) # self.assertEqual(datetime.time(4, 31, 59), result_2)
self.assertEqual(datetime.time(4, 31, 59, 200), result_3) # self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
def test_from_native_datetime_time(self): # def test_from_native_datetime_time(self):
""" # """
Make sure from_native() accepts a datetime.time instance. # Make sure from_native() accepts a datetime.time instance.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
result_1 = f.from_native(datetime.time(4, 31)) # result_1 = f.from_native(datetime.time(4, 31))
result_2 = f.from_native(datetime.time(4, 31, 59)) # result_2 = f.from_native(datetime.time(4, 31, 59))
result_3 = f.from_native(datetime.time(4, 31, 59, 200)) # result_3 = f.from_native(datetime.time(4, 31, 59, 200))
self.assertEqual(result_1, datetime.time(4, 31)) # self.assertEqual(result_1, datetime.time(4, 31))
self.assertEqual(result_2, datetime.time(4, 31, 59)) # self.assertEqual(result_2, datetime.time(4, 31, 59))
self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) # self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
def test_from_native_custom_format(self): # def test_from_native_custom_format(self):
""" # """
Make sure from_native() accepts custom input formats. # Make sure from_native() accepts custom input formats.
""" # """
f = serializers.TimeField(input_formats=['%H -- %M']) # f = serializers.TimeField(input_formats=['%H -- %M'])
result = f.from_native('04 -- 31') # result = f.from_native('04 -- 31')
self.assertEqual(datetime.time(4, 31), result) # self.assertEqual(datetime.time(4, 31), result)
def test_from_native_invalid_default_on_custom_format(self): # def test_from_native_invalid_default_on_custom_format(self):
""" # """
Make sure from_native() don't accept default formats if custom format is preset # Make sure from_native() don't accept default formats if custom format is preset
""" # """
f = serializers.TimeField(input_formats=['%H -- %M']) # f = serializers.TimeField(input_formats=['%H -- %M'])
try: # try:
f.from_native('04:31:59') # f.from_native('04:31:59')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) # self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_empty(self): # def test_from_native_empty(self):
""" # """
Make sure from_native() returns None on empty param. # Make sure from_native() returns None on empty param.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
result = f.from_native('') # result = f.from_native('')
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_none(self): # def test_from_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
result = f.from_native(None) # result = f.from_native(None)
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_invalid_time(self): # def test_from_native_invalid_time(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid time. # Make sure from_native() raises a ValidationError on passing an invalid time.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
try: # try:
f.from_native('04:61:59') # f.from_native('04:61:59')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " # self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
"hh:mm[:ss[.uuuuuu]]"]) # "hh:mm[:ss[.uuuuuu]]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_invalid_format(self): # def test_from_native_invalid_format(self):
""" # """
Make sure from_native() raises a ValidationError on passing an invalid format. # Make sure from_native() raises a ValidationError on passing an invalid format.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
try: # try:
f.from_native('04 -- 31') # f.from_native('04 -- 31')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " # self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
"hh:mm[:ss[.uuuuuu]]"]) # "hh:mm[:ss[.uuuuuu]]"])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_to_native(self): # def test_to_native(self):
""" # """
Make sure to_native() returns time object as default. # Make sure to_native() returns time object as default.
""" # """
f = serializers.TimeField() # f = serializers.TimeField()
result_1 = f.to_native(datetime.time(4, 31)) # result_1 = f.to_native(datetime.time(4, 31))
result_2 = f.to_native(datetime.time(4, 31, 59)) # result_2 = f.to_native(datetime.time(4, 31, 59))
result_3 = f.to_native(datetime.time(4, 31, 59, 200)) # result_3 = f.to_native(datetime.time(4, 31, 59, 200))
self.assertEqual(datetime.time(4, 31), result_1) # self.assertEqual(datetime.time(4, 31), result_1)
self.assertEqual(datetime.time(4, 31, 59), result_2) # self.assertEqual(datetime.time(4, 31, 59), result_2)
self.assertEqual(datetime.time(4, 31, 59, 200), result_3) # self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
def test_to_native_iso(self): # def test_to_native_iso(self):
""" # """
Make sure to_native() with format='iso-8601' returns iso formatted time. # Make sure to_native() with format='iso-8601' returns iso formatted time.
""" # """
f = serializers.TimeField(format='iso-8601') # f = serializers.TimeField(format='iso-8601')
result_1 = f.to_native(datetime.time(4, 31)) # result_1 = f.to_native(datetime.time(4, 31))
result_2 = f.to_native(datetime.time(4, 31, 59)) # result_2 = f.to_native(datetime.time(4, 31, 59))
result_3 = f.to_native(datetime.time(4, 31, 59, 200)) # result_3 = f.to_native(datetime.time(4, 31, 59, 200))
self.assertEqual('04:31:00', result_1) # self.assertEqual('04:31:00', result_1)
self.assertEqual('04:31:59', result_2) # self.assertEqual('04:31:59', result_2)
self.assertEqual('04:31:59.000200', result_3) # self.assertEqual('04:31:59.000200', result_3)
def test_to_native_custom_format(self): # def test_to_native_custom_format(self):
""" # """
Make sure to_native() returns correct custom format. # Make sure to_native() returns correct custom format.
""" # """
f = serializers.TimeField(format="%H - %S [%f]") # f = serializers.TimeField(format="%H - %S [%f]")
result_1 = f.to_native(datetime.time(4, 31)) # result_1 = f.to_native(datetime.time(4, 31))
result_2 = f.to_native(datetime.time(4, 31, 59)) # result_2 = f.to_native(datetime.time(4, 31, 59))
result_3 = f.to_native(datetime.time(4, 31, 59, 200)) # result_3 = f.to_native(datetime.time(4, 31, 59, 200))
self.assertEqual('04 - 00 [000000]', result_1) # self.assertEqual('04 - 00 [000000]', result_1)
self.assertEqual('04 - 59 [000000]', result_2) # self.assertEqual('04 - 59 [000000]', result_2)
self.assertEqual('04 - 59 [000200]', result_3) # self.assertEqual('04 - 59 [000200]', result_3)
class DecimalFieldTest(TestCase): # class DecimalFieldTest(TestCase):
""" # """
Tests for the DecimalField from_native() and to_native() behavior # Tests for the DecimalField from_native() and to_native() behavior
""" # """
def test_from_native_string(self): # def test_from_native_string(self):
""" # """
Make sure from_native() accepts string values # Make sure from_native() accepts string values
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result_1 = f.from_native('9000') # result_1 = f.from_native('9000')
result_2 = f.from_native('1.00000001') # result_2 = f.from_native('1.00000001')
self.assertEqual(Decimal('9000'), result_1) # self.assertEqual(Decimal('9000'), result_1)
self.assertEqual(Decimal('1.00000001'), result_2) # self.assertEqual(Decimal('1.00000001'), result_2)
def test_from_native_invalid_string(self): # def test_from_native_invalid_string(self):
""" # """
Make sure from_native() raises ValidationError on passing invalid string # Make sure from_native() raises ValidationError on passing invalid string
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
try: # try:
f.from_native('123.45.6') # f.from_native('123.45.6')
except validators.ValidationError as e: # except validators.ValidationError as e:
self.assertEqual(e.messages, ["Enter a number."]) # self.assertEqual(e.messages, ["Enter a number."])
else: # else:
self.fail("ValidationError was not properly raised") # self.fail("ValidationError was not properly raised")
def test_from_native_integer(self): # def test_from_native_integer(self):
""" # """
Make sure from_native() accepts integer values # Make sure from_native() accepts integer values
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result = f.from_native(9000) # result = f.from_native(9000)
self.assertEqual(Decimal('9000'), result) # self.assertEqual(Decimal('9000'), result)
def test_from_native_float(self): # def test_from_native_float(self):
""" # """
Make sure from_native() accepts float values # Make sure from_native() accepts float values
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result = f.from_native(1.00000001) # result = f.from_native(1.00000001)
self.assertEqual(Decimal('1.00000001'), result) # self.assertEqual(Decimal('1.00000001'), result)
def test_from_native_empty(self): # def test_from_native_empty(self):
""" # """
Make sure from_native() returns None on empty param. # Make sure from_native() returns None on empty param.
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result = f.from_native('') # result = f.from_native('')
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_from_native_none(self): # def test_from_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result = f.from_native(None) # result = f.from_native(None)
self.assertEqual(result, None) # self.assertEqual(result, None)
def test_to_native(self): # def test_to_native(self):
""" # """
Make sure to_native() returns Decimal as string. # Make sure to_native() returns Decimal as string.
""" # """
f = serializers.DecimalField() # f = serializers.DecimalField()
result_1 = f.to_native(Decimal('9000')) # result_1 = f.to_native(Decimal('9000'))
result_2 = f.to_native(Decimal('1.00000001')) # result_2 = f.to_native(Decimal('1.00000001'))
self.assertEqual(Decimal('9000'), result_1) # self.assertEqual(Decimal('9000'), result_1)
self.assertEqual(Decimal('1.00000001'), result_2) # self.assertEqual(Decimal('1.00000001'), result_2)
def test_to_native_none(self): # def test_to_native_none(self):
""" # """
Make sure from_native() returns None on None param. # Make sure from_native() returns None on None param.
""" # """
f = serializers.DecimalField(required=False) # f = serializers.DecimalField(required=False)
self.assertEqual(None, f.to_native(None)) # self.assertEqual(None, f.to_native(None))
def test_valid_serialization(self): # def test_valid_serialization(self):
""" # """
Make sure the serializer works correctly # Make sure the serializer works correctly
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=9010, # decimal_field = serializers.DecimalField(max_value=9010,
min_value=9000, # min_value=9000,
max_digits=6, # max_digits=6,
decimal_places=2) # decimal_places=2)
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) # self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) # self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) # self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) # self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) # self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) # self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
def test_raise_max_value(self): # def test_raise_max_value(self):
""" # """
Make sure max_value violations raises ValidationError # Make sure max_value violations raises ValidationError
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=100) # decimal_field = serializers.DecimalField(max_value=100)
s = DecimalSerializer(data={'decimal_field': '123'}) # s = DecimalSerializer(data={'decimal_field': '123'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) # self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
def test_raise_min_value(self): # def test_raise_min_value(self):
""" # """
Make sure min_value violations raises ValidationError # Make sure min_value violations raises ValidationError
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(min_value=100) # decimal_field = serializers.DecimalField(min_value=100)
s = DecimalSerializer(data={'decimal_field': '99'}) # s = DecimalSerializer(data={'decimal_field': '99'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) # self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
def test_raise_max_digits(self): # def test_raise_max_digits(self):
""" # """
Make sure max_digits violations raises ValidationError # Make sure max_digits violations raises ValidationError
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=5) # decimal_field = serializers.DecimalField(max_digits=5)
s = DecimalSerializer(data={'decimal_field': '123.456'}) # s = DecimalSerializer(data={'decimal_field': '123.456'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) # self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
def test_raise_max_decimal_places(self): # def test_raise_max_decimal_places(self):
""" # """
Make sure max_decimal_places violations raises ValidationError # Make sure max_decimal_places violations raises ValidationError
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(decimal_places=3) # decimal_field = serializers.DecimalField(decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '123.4567'}) # s = DecimalSerializer(data={'decimal_field': '123.4567'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) # self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
def test_raise_max_whole_digits(self): # def test_raise_max_whole_digits(self):
""" # """
Make sure max_whole_digits violations raises ValidationError # Make sure max_whole_digits violations raises ValidationError
""" # """
class DecimalSerializer(serializers.Serializer): # class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) # decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '12345.6'}) # s = DecimalSerializer(data={'decimal_field': '12345.6'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) # self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
class ChoiceFieldTests(TestCase): # class ChoiceFieldTests(TestCase):
""" # """
Tests for the ChoiceField options generator # Tests for the ChoiceField options generator
""" # """
def test_choices_required(self): # def test_choices_required(self):
""" # """
Make sure proper choices are rendered if field is required # Make sure proper choices are rendered if field is required
""" # """
f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES) # f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
self.assertEqual(f.choices, SAMPLE_CHOICES) # self.assertEqual(f.choices, SAMPLE_CHOICES)
def test_choices_not_required(self): # def test_choices_not_required(self):
""" # """
Make sure proper choices (plus blank) are rendered if the field isn't required # Make sure proper choices (plus blank) are rendered if the field isn't required
""" # """
f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) # f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) # self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
def test_blank_choice_display(self): # def test_blank_choice_display(self):
blank = 'No Preference' # blank = 'No Preference'
f = serializers.ChoiceField( # f = serializers.ChoiceField(
required=False, # required=False,
choices=SAMPLE_CHOICES, # choices=SAMPLE_CHOICES,
blank_display_value=blank, # blank_display_value=blank,
) # )
self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES) # self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES)
def test_invalid_choice_model(self): # def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) # s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid()) # self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) # self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
self.assertEqual(s.data['choice'], '') # self.assertEqual(s.data['choice'], '')
def test_empty_choice_model(self): # def test_empty_choice_model(self):
""" # """
Test that the 'empty' value is correctly passed and used depending on # Test that the 'empty' value is correctly passed and used depending on
the 'null' property on the model field. # the 'null' property on the model field.
""" # """
s = ChoiceFieldModelSerializer(data={'choice': ''}) # s = ChoiceFieldModelSerializer(data={'choice': ''})
self.assertTrue(s.is_valid()) # self.assertTrue(s.is_valid())
self.assertEqual(s.data['choice'], '') # self.assertEqual(s.data['choice'], '')
s = ChoiceFieldModelWithNullSerializer(data={'choice': ''}) # s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
self.assertTrue(s.is_valid()) # self.assertTrue(s.is_valid())
self.assertEqual(s.data['choice'], None) # self.assertEqual(s.data['choice'], None)
def test_from_native_empty(self): # def test_from_native_empty(self):
""" # """
Make sure from_native() returns an empty string on empty param by default. # Make sure from_native() returns an empty string on empty param by default.
""" # """
f = serializers.ChoiceField(choices=SAMPLE_CHOICES) # f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
self.assertEqual(f.from_native(''), '') # self.assertEqual(f.from_native(''), '')
self.assertEqual(f.from_native(None), '') # self.assertEqual(f.from_native(None), '')
def test_from_native_empty_override(self): # def test_from_native_empty_override(self):
""" # """
Make sure you can override from_native() behavior regarding empty values. # Make sure you can override from_native() behavior regarding empty values.
""" # """
f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None) # f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
self.assertEqual(f.from_native(''), None) # self.assertEqual(f.from_native(''), None)
self.assertEqual(f.from_native(None), None) # self.assertEqual(f.from_native(None), None)
def test_metadata_choices(self): # def test_metadata_choices(self):
""" # """
Make sure proper choices are included in the field's metadata. # Make sure proper choices are included in the field's metadata.
""" # """
choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES] # choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
f = serializers.ChoiceField(choices=SAMPLE_CHOICES) # f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
self.assertEqual(f.metadata()['choices'], choices) # self.assertEqual(f.metadata()['choices'], choices)
def test_metadata_choices_not_required(self): # def test_metadata_choices_not_required(self):
""" # """
Make sure proper choices are included in the field's metadata. # Make sure proper choices are included in the field's metadata.
""" # """
choices = [{'value': v, 'display_name': n} # choices = [{'value': v, 'display_name': n}
for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES] # for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) # f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
self.assertEqual(f.metadata()['choices'], choices) # self.assertEqual(f.metadata()['choices'], choices)
class EmailFieldTests(TestCase): # class EmailFieldTests(TestCase):
""" # """
Tests for EmailField attribute values # Tests for EmailField attribute values
""" # """
class EmailFieldModel(RESTFrameworkModel): # class EmailFieldModel(RESTFrameworkModel):
email_field = models.EmailField(blank=True) # email_field = models.EmailField(blank=True)
class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): # class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
email_field = models.EmailField(max_length=150, blank=True) # email_field = models.EmailField(max_length=150, blank=True)
def test_default_model_value(self): # def test_default_model_value(self):
class EmailFieldSerializer(serializers.ModelSerializer): # class EmailFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.EmailFieldModel # model = self.EmailFieldModel
serializer = EmailFieldSerializer(data={}) # serializer = EmailFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) # self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
def test_given_model_value(self): # def test_given_model_value(self):
class EmailFieldSerializer(serializers.ModelSerializer): # class EmailFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.EmailFieldWithGivenMaxLengthModel # model = self.EmailFieldWithGivenMaxLengthModel
serializer = EmailFieldSerializer(data={}) # serializer = EmailFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) # self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
def test_given_serializer_value(self): # def test_given_serializer_value(self):
class EmailFieldSerializer(serializers.ModelSerializer): # class EmailFieldSerializer(serializers.ModelSerializer):
email_field = serializers.EmailField(source='email_field', max_length=20, required=False) # email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
class Meta: # class Meta:
model = self.EmailFieldModel # model = self.EmailFieldModel
serializer = EmailFieldSerializer(data={}) # serializer = EmailFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) # self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
class SlugFieldTests(TestCase): # class SlugFieldTests(TestCase):
""" # """
Tests for SlugField attribute values # Tests for SlugField attribute values
""" # """
class SlugFieldModel(RESTFrameworkModel): # class SlugFieldModel(RESTFrameworkModel):
slug_field = models.SlugField(blank=True) # slug_field = models.SlugField(blank=True)
class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): # class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
slug_field = models.SlugField(max_length=84, blank=True) # slug_field = models.SlugField(max_length=84, blank=True)
def test_default_model_value(self): # def test_default_model_value(self):
class SlugFieldSerializer(serializers.ModelSerializer): # class SlugFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.SlugFieldModel # model = self.SlugFieldModel
serializer = SlugFieldSerializer(data={}) # serializer = SlugFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) # self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
def test_given_model_value(self): # def test_given_model_value(self):
class SlugFieldSerializer(serializers.ModelSerializer): # class SlugFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.SlugFieldWithGivenMaxLengthModel # model = self.SlugFieldWithGivenMaxLengthModel
serializer = SlugFieldSerializer(data={}) # serializer = SlugFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) # self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
def test_given_serializer_value(self): # def test_given_serializer_value(self):
class SlugFieldSerializer(serializers.ModelSerializer): # class SlugFieldSerializer(serializers.ModelSerializer):
slug_field = serializers.SlugField(source='slug_field', # slug_field = serializers.SlugField(source='slug_field',
max_length=20, required=False) # max_length=20, required=False)
class Meta: # class Meta:
model = self.SlugFieldModel # model = self.SlugFieldModel
serializer = SlugFieldSerializer(data={}) # serializer = SlugFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['slug_field'], # self.assertEqual(getattr(serializer.fields['slug_field'],
'max_length'), 20) # 'max_length'), 20)
def test_invalid_slug(self): # def test_invalid_slug(self):
""" # """
Make sure an invalid slug raises ValidationError # Make sure an invalid slug raises ValidationError
""" # """
class SlugFieldSerializer(serializers.ModelSerializer): # class SlugFieldSerializer(serializers.ModelSerializer):
slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) # slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
class Meta: # class Meta:
model = self.SlugFieldModel # model = self.SlugFieldModel
s = SlugFieldSerializer(data={'slug_field': 'a b'}) # s = SlugFieldSerializer(data={'slug_field': 'a b'})
self.assertEqual(s.is_valid(), False) # self.assertEqual(s.is_valid(), False)
self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) # self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
class URLFieldTests(TestCase): # class URLFieldTests(TestCase):
""" # """
Tests for URLField attribute values. # Tests for URLField attribute values.
(Includes test for #1210, checking that validators can be overridden.) # (Includes test for #1210, checking that validators can be overridden.)
""" # """
class URLFieldModel(RESTFrameworkModel): # class URLFieldModel(RESTFrameworkModel):
url_field = models.URLField(blank=True) # url_field = models.URLField(blank=True)
class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): # class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
url_field = models.URLField(max_length=128, blank=True) # url_field = models.URLField(max_length=128, blank=True)
def test_default_model_value(self): # def test_default_model_value(self):
class URLFieldSerializer(serializers.ModelSerializer): # class URLFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.URLFieldModel # model = self.URLFieldModel
serializer = URLFieldSerializer(data={}) # serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['url_field'], # self.assertEqual(getattr(serializer.fields['url_field'],
'max_length'), 200) # 'max_length'), 200)
def test_given_model_value(self): # def test_given_model_value(self):
class URLFieldSerializer(serializers.ModelSerializer): # class URLFieldSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = self.URLFieldWithGivenMaxLengthModel # model = self.URLFieldWithGivenMaxLengthModel
serializer = URLFieldSerializer(data={}) # serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['url_field'], # self.assertEqual(getattr(serializer.fields['url_field'],
'max_length'), 128) # 'max_length'), 128)
def test_given_serializer_value(self): # def test_given_serializer_value(self):
class URLFieldSerializer(serializers.ModelSerializer): # class URLFieldSerializer(serializers.ModelSerializer):
url_field = serializers.URLField(source='url_field', # url_field = serializers.URLField(source='url_field',
max_length=20, required=False) # max_length=20, required=False)
class Meta: # class Meta:
model = self.URLFieldWithGivenMaxLengthModel # model = self.URLFieldWithGivenMaxLengthModel
serializer = URLFieldSerializer(data={}) # serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(getattr(serializer.fields['url_field'], # self.assertEqual(getattr(serializer.fields['url_field'],
'max_length'), 20) # 'max_length'), 20)
def test_validators_can_be_overridden(self): # def test_validators_can_be_overridden(self):
url_field = serializers.URLField(validators=[]) # url_field = serializers.URLField(validators=[])
validators = url_field.validators # validators = url_field.validators
self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') # self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators')
class FieldMetadata(TestCase): # class FieldMetadata(TestCase):
def setUp(self): # def setUp(self):
self.required_field = serializers.Field() # self.required_field = serializers.Field()
self.required_field.label = uuid4().hex # self.required_field.label = uuid4().hex
self.required_field.required = True # self.required_field.required = True
self.optional_field = serializers.Field() # self.optional_field = serializers.Field()
self.optional_field.label = uuid4().hex # self.optional_field.label = uuid4().hex
self.optional_field.required = False # self.optional_field.required = False
def test_required(self): # def test_required(self):
self.assertEqual(self.required_field.metadata()['required'], True) # self.assertEqual(self.required_field.metadata()['required'], True)
def test_optional(self): # def test_optional(self):
self.assertEqual(self.optional_field.metadata()['required'], False) # self.assertEqual(self.optional_field.metadata()['required'], False)
def test_label(self): # def test_label(self):
for field in (self.required_field, self.optional_field): # for field in (self.required_field, self.optional_field):
self.assertEqual(field.metadata()['label'], field.label) # self.assertEqual(field.metadata()['label'], field.label)
class FieldCallableDefault(TestCase): # class FieldCallableDefault(TestCase):
def setUp(self): # def setUp(self):
self.simple_callable = lambda: 'foo bar' # self.simple_callable = lambda: 'foo bar'
def test_default_can_be_simple_callable(self): # def test_default_can_be_simple_callable(self):
""" # """
Ensure that the 'default' argument can also be a simple callable. # Ensure that the 'default' argument can also be a simple callable.
""" # """
field = serializers.WritableField(default=self.simple_callable) # field = serializers.WritableField(default=self.simple_callable)
into = {} # into = {}
field.field_from_native({}, {}, 'field', into) # field.field_from_native({}, {}, 'field', into)
self.assertEqual(into, {'field': 'foo bar'}) # self.assertEqual(into, {'field': 'foo bar'})
class CustomIntegerField(TestCase): # class CustomIntegerField(TestCase):
""" # """
Test that custom fields apply min_value and max_value constraints # Test that custom fields apply min_value and max_value constraints
""" # """
def test_custom_fields_can_be_validated_for_value(self): # def test_custom_fields_can_be_validated_for_value(self):
class MoneyField(models.PositiveIntegerField): # class MoneyField(models.PositiveIntegerField):
pass # pass
class EntryModel(models.Model): # class EntryModel(models.Model):
bank = MoneyField(validators=[validators.MaxValueValidator(100)]) # bank = MoneyField(validators=[validators.MaxValueValidator(100)])
class EntrySerializer(serializers.ModelSerializer): # class EntrySerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = EntryModel # model = EntryModel
# entry = EntryModel(bank=1)
# serializer = EntrySerializer(entry, data={"bank": 11})
# self.assertTrue(serializer.is_valid())
# serializer = EntrySerializer(entry, data={"bank": -1})
# self.assertFalse(serializer.is_valid())
# serializer = EntrySerializer(entry, data={"bank": 101})
# self.assertFalse(serializer.is_valid())
# class BooleanField(TestCase):
# """
# Tests for BooleanField
# """
# def test_boolean_required(self):
# class BooleanRequiredSerializer(serializers.Serializer):
# bool_field = serializers.BooleanField(required=True)
# self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
# class SerializerMethodFieldTest(TestCase):
# """
# Tests for the SerializerMethodField field_to_native() behavior
# """
# class SerializerTest(serializers.Serializer):
# def get_my_test(self, obj):
# return obj.my_test[0:5]
# class ModelCharField(TestCase):
# """
# Tests for CharField
# """
# def test_none_serializing(self):
# class CharFieldSerializer(serializers.Serializer):
# char = serializers.CharField(allow_none=True, required=False)
# serializer = CharFieldSerializer(data={'char': None})
# self.assertTrue(serializer.is_valid())
# self.assertIsNone(serializer.object['char'])
entry = EntryModel(bank=1)
serializer = EntrySerializer(entry, data={"bank": 11}) # class SerializerMethodFieldTest(TestCase):
self.assertTrue(serializer.is_valid()) # """
# Tests for the SerializerMethodField field_to_native() behavior
# """
# class SerializerTest(serializers.Serializer):
# def get_my_test(self, obj):
# return obj.my_test[0:5]
# class Example():
# my_test = 'Hey, this is a test !'
serializer = EntrySerializer(entry, data={"bank": -1}) # def test_field_to_native(self):
self.assertFalse(serializer.is_valid()) # s = serializers.SerializerMethodField('get_my_test')
# s.initialize(self.SerializerTest(), 'name')
serializer = EntrySerializer(entry, data={"bank": 101}) # result = s.field_to_native(self.Example(), None)
self.assertFalse(serializer.is_valid()) # self.assertEqual(result, 'Hey, ')
class BooleanField(TestCase):
"""
Tests for BooleanField
"""
def test_boolean_required(self):
class BooleanRequiredSerializer(serializers.Serializer):
bool_field = serializers.BooleanField(required=True)
self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
class ModelCharField(TestCase):
"""
Tests for CharField
"""
def test_none_serializing(self):
class CharFieldSerializer(serializers.Serializer):
char = serializers.CharField(allow_none=True, required=False)
serializer = CharFieldSerializer(data={'char': None})
self.assertTrue(serializer.is_valid())
self.assertIsNone(serializer.object['char'])
class SerializerMethodFieldTest(TestCase):
"""
Tests for the SerializerMethodField field_to_native() behavior
"""
class SerializerTest(serializers.Serializer):
def get_my_test(self, obj):
return obj.my_test[0:5]
class Example():
my_test = 'Hey, this is a test !'
def test_field_to_native(self):
s = serializers.SerializerMethodField('get_my_test')
s.initialize(self.SerializerTest(), 'name')
result = s.field_to_native(self.Example(), None)
self.assertEqual(result, 'Hey, ')
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.test import TestCase # from django.test import TestCase
from django.utils import six # from django.utils import six
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import BytesIO # from rest_framework.compat import BytesIO
import datetime # import datetime
class UploadedFile(object): # class UploadedFile(object):
def __init__(self, file=None, created=None): # def __init__(self, file=None, created=None):
self.file = file # self.file = file
self.created = created or datetime.datetime.now() # self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer): # class UploadedFileSerializer(serializers.Serializer):
file = serializers.FileField(required=False) # file = serializers.FileField(required=False)
created = serializers.DateTimeField() # created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.file = attrs['file'] # instance.file = attrs['file']
instance.created = attrs['created'] # instance.created = attrs['created']
return instance # return instance
return UploadedFile(**attrs) # return UploadedFile(**attrs)
class FileSerializerTests(TestCase): # class FileSerializerTests(TestCase):
def test_create(self): # def test_create(self):
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) # serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertEqual(serializer.object.file, uploaded_file.file) # self.assertEqual(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file) # self.assertFalse(serializer.object is uploaded_file)
def test_creation_failure(self): # def test_creation_failure(self):
""" # """
Passing files=None should result in an ValidationError # Passing files=None should result in an ValidationError
Regression test for: # Regression test for:
https://github.com/tomchristie/django-rest-framework/issues/542 # https://github.com/tomchristie/django-rest-framework/issues/542
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now}) # serializer = UploadedFileSerializer(data={'created': now})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, now) # self.assertEqual(serializer.object.created, now)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_remove_with_empty_string(self): # def test_remove_with_empty_string(self):
""" # """
Passing empty string as data should cause file to be removed # Passing empty string as data should cause file to be removed
Test for: # Test for:
https://github.com/tomchristie/django-rest-framework/issues/937 # https://github.com/tomchristie/django-rest-framework/issues/937
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) # serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_validation_error_with_non_file(self): # def test_validation_error_with_non_file(self):
""" # """
Passing non-files should raise a validation error. # Passing non-files should raise a validation error.
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
errmsg = 'No file was submitted. Check the encoding type on the form.' # errmsg = 'No file was submitted. Check the encoding type on the form.'
serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) # serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'file': [errmsg]}) # self.assertEqual(serializer.errors, {'file': [errmsg]})
def test_validation_with_no_data(self): # def test_validation_with_no_data(self):
""" # """
Validation should still function when no data dictionary is provided. # Validation should still function when no data dictionary is provided.
""" # """
uploaded_file = BytesIO(six.b('stuff')) # uploaded_file = BytesIO(six.b('stuff'))
uploaded_file.name = 'stuff.txt' # uploaded_file.name = 'stuff.txt'
uploaded_file.size = len(uploaded_file.getvalue()) # uploaded_file.size = len(uploaded_file.getvalue())
serializer = UploadedFileSerializer(files={'file': uploaded_file}) # serializer = UploadedFileSerializer(files={'file': uploaded_file})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
...@@ -2,10 +2,11 @@ from __future__ import unicode_literals ...@@ -2,10 +2,11 @@ from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db import models from django.db import models
from django.conf.urls import patterns, url
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.conf.urls import patterns, url from django.utils.dateparse import parse_date
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
...@@ -16,9 +17,14 @@ factory = APIRequestFactory() ...@@ -16,9 +17,14 @@ factory = APIRequestFactory()
if django_filters: if django_filters:
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
model = FilterableItem
# Basic filter on a list view. # Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_fields = ['decimal', 'date'] filter_fields = ['decimal', 'date']
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
...@@ -33,7 +39,8 @@ if django_filters: ...@@ -33,7 +39,8 @@ if django_filters:
fields = ['text', 'decimal', 'date'] fields = ['text', 'decimal', 'date']
class FilterClassRootView(generics.ListCreateAPIView): class FilterClassRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_class = SeveralFieldsFilter filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
...@@ -46,12 +53,14 @@ if django_filters: ...@@ -46,12 +53,14 @@ if django_filters:
fields = ['text'] fields = ['text']
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_class = MisconfiguredFilter filter_class = MisconfiguredFilter
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
class FilterClassDetailView(generics.RetrieveAPIView): class FilterClassDetailView(generics.RetrieveAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_class = SeveralFieldsFilter filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
...@@ -63,15 +72,12 @@ if django_filters: ...@@ -63,15 +72,12 @@ if django_filters:
model = BaseFilterableItem model = BaseFilterableItem
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_class = BaseFilterableItemFilter filter_class = BaseFilterableItemFilter
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
# Regression test for #814 # Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
model = FilterableItem
class FilterFieldsQuerysetView(generics.ListCreateAPIView): class FilterFieldsQuerysetView(generics.ListCreateAPIView):
queryset = FilterableItem.objects.all() queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer serializer_class = FilterableItemSerializer
...@@ -97,7 +103,7 @@ if django_filters: ...@@ -97,7 +103,7 @@ if django_filters:
class CommonFilteringTestCase(TestCase): class CommonFilteringTestCase(TestCase):
def _serialize_object(self, obj): def _serialize_object(self, obj):
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
def setUp(self): def setUp(self):
""" """
...@@ -140,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -140,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
# Tests that the date filter works. # Tests that the date filter works.
...@@ -148,7 +154,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -148,7 +154,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if parse_date(f['date']) == search_date]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filter not installed') @unittest.skipUnless(django_filters, 'django-filter not installed')
...@@ -163,7 +169,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -163,7 +169,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filter not installed') @unittest.skipUnless(django_filters, 'django-filter not installed')
...@@ -196,7 +202,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -196,7 +202,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
...@@ -204,7 +210,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -204,7 +210,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if parse_date(f['date']) > search_date]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class works.
...@@ -224,8 +230,8 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -224,8 +230,8 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
}) })
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if parse_date(f['date']) > search_date and
f['decimal'] < search_decimal] Decimal(f['decimal']) < search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filter not installed') @unittest.skipUnless(django_filters, 'django-filter not installed')
...@@ -323,6 +329,11 @@ class SearchFilterModel(models.Model): ...@@ -323,6 +329,11 @@ class SearchFilterModel(models.Model):
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
class SearchFilterSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModel
class SearchFilterTests(TestCase): class SearchFilterTests(TestCase):
def setUp(self): def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
...@@ -342,7 +353,8 @@ class SearchFilterTests(TestCase): ...@@ -342,7 +353,8 @@ class SearchFilterTests(TestCase):
def test_search(self): def test_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
model = SearchFilterModel queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text') search_fields = ('title', 'text')
...@@ -359,7 +371,8 @@ class SearchFilterTests(TestCase): ...@@ -359,7 +371,8 @@ class SearchFilterTests(TestCase):
def test_exact_search(self): def test_exact_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
model = SearchFilterModel queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text') search_fields = ('=title', 'text')
...@@ -375,7 +388,8 @@ class SearchFilterTests(TestCase): ...@@ -375,7 +388,8 @@ class SearchFilterTests(TestCase):
def test_startswith_search(self): def test_startswith_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
model = SearchFilterModel queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text') search_fields = ('title', '^text')
...@@ -392,7 +406,8 @@ class SearchFilterTests(TestCase): ...@@ -392,7 +406,8 @@ class SearchFilterTests(TestCase):
def test_search_with_nonstandard_search_param(self): def test_search_with_nonstandard_search_param(self):
with temporary_setting('SEARCH_PARAM', 'query', module=filters): with temporary_setting('SEARCH_PARAM', 'query', module=filters):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
model = SearchFilterModel queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text') search_fields = ('title', 'text')
...@@ -418,6 +433,11 @@ class OrderingFilterRelatedModel(models.Model): ...@@ -418,6 +433,11 @@ class OrderingFilterRelatedModel(models.Model):
related_name="relateds") related_name="relateds")
class OrderingFilterSerializer(serializers.ModelSerializer):
class Meta:
model = OrderingFilterModel
class DjangoFilterOrderingModel(models.Model): class DjangoFilterOrderingModel(models.Model):
date = models.DateField() date = models.DateField()
text = models.CharField(max_length=10) text = models.CharField(max_length=10)
...@@ -426,6 +446,11 @@ class DjangoFilterOrderingModel(models.Model): ...@@ -426,6 +446,11 @@ class DjangoFilterOrderingModel(models.Model):
ordering = ['-date'] ordering = ['-date']
class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
class Meta:
model = DjangoFilterOrderingModel
class DjangoFilterOrderingTests(TestCase): class DjangoFilterOrderingTests(TestCase):
def setUp(self): def setUp(self):
data = [{ data = [{
...@@ -444,7 +469,8 @@ class DjangoFilterOrderingTests(TestCase): ...@@ -444,7 +469,8 @@ class DjangoFilterOrderingTests(TestCase):
def test_default_ordering(self): def test_default_ordering(self):
class DjangoFilterOrderingView(generics.ListAPIView): class DjangoFilterOrderingView(generics.ListAPIView):
model = DjangoFilterOrderingModel serializer_class = DjangoFilterOrderingSerializer
queryset = DjangoFilterOrderingModel.objects.all()
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
filter_fields = ['text'] filter_fields = ['text']
ordering = ('-date',) ordering = ('-date',)
...@@ -456,9 +482,9 @@ class DjangoFilterOrderingTests(TestCase): ...@@ -456,9 +482,9 @@ class DjangoFilterOrderingTests(TestCase):
self.assertEqual( self.assertEqual(
response.data, response.data,
[ [
{'id': 3, 'date': datetime.date(2014, 10, 8), 'text': 'cde'}, {'id': 3, 'date': '2014-10-08', 'text': 'cde'},
{'id': 2, 'date': datetime.date(2013, 10, 8), 'text': 'bcd'}, {'id': 2, 'date': '2013-10-08', 'text': 'bcd'},
{'id': 1, 'date': datetime.date(2012, 10, 8), 'text': 'abc'} {'id': 1, 'date': '2012-10-08', 'text': 'abc'}
] ]
) )
...@@ -485,7 +511,8 @@ class OrderingFilterTests(TestCase): ...@@ -485,7 +511,8 @@ class OrderingFilterTests(TestCase):
def test_ordering(self): def test_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
...@@ -504,7 +531,8 @@ class OrderingFilterTests(TestCase): ...@@ -504,7 +531,8 @@ class OrderingFilterTests(TestCase):
def test_reverse_ordering(self): def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
...@@ -523,7 +551,8 @@ class OrderingFilterTests(TestCase): ...@@ -523,7 +551,8 @@ class OrderingFilterTests(TestCase):
def test_incorrectfield_ordering(self): def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
...@@ -542,7 +571,8 @@ class OrderingFilterTests(TestCase): ...@@ -542,7 +571,8 @@ class OrderingFilterTests(TestCase):
def test_default_ordering(self): def test_default_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
oredering_fields = ('text',) oredering_fields = ('text',)
...@@ -561,7 +591,8 @@ class OrderingFilterTests(TestCase): ...@@ -561,7 +591,8 @@ class OrderingFilterTests(TestCase):
def test_default_ordering_using_string(self): def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = ('text',) ordering_fields = ('text',)
...@@ -590,7 +621,7 @@ class OrderingFilterTests(TestCase): ...@@ -590,7 +621,7 @@ class OrderingFilterTests(TestCase):
new_related.save() new_related.save()
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = '__all__' ordering_fields = '__all__'
...@@ -612,7 +643,8 @@ class OrderingFilterTests(TestCase): ...@@ -612,7 +643,8 @@ class OrderingFilterTests(TestCase):
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with temporary_setting('ORDERING_PARAM', 'order', filters): with temporary_setting('ORDERING_PARAM', 'order', filters):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
model = OrderingFilterModel queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
......
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.contrib.contenttypes.models import ContentType # from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey # from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import python_2_unicode_compatible # from rest_framework.compat import python_2_unicode_compatible
@python_2_unicode_compatible # @python_2_unicode_compatible
class Tag(models.Model): # class Tag(models.Model):
""" # """
Tags have a descriptive slug, and are attached to an arbitrary object. # Tags have a descriptive slug, and are attached to an arbitrary object.
""" # """
tag = models.SlugField() # tag = models.SlugField()
content_type = models.ForeignKey(ContentType) # content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField() # object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') # tagged_item = GenericForeignKey('content_type', 'object_id')
def __str__(self): # def __str__(self):
return self.tag # return self.tag
@python_2_unicode_compatible # @python_2_unicode_compatible
class Bookmark(models.Model): # class Bookmark(models.Model):
""" # """
A URL bookmark that may have multiple tags attached. # A URL bookmark that may have multiple tags attached.
""" # """
url = models.URLField() # url = models.URLField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Bookmark: %s' % self.url # return 'Bookmark: %s' % self.url
@python_2_unicode_compatible # @python_2_unicode_compatible
class Note(models.Model): # class Note(models.Model):
""" # """
A textual note that may have multiple tags attached. # A textual note that may have multiple tags attached.
""" # """
text = models.TextField() # text = models.TextField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Note: %s' % self.text # return 'Note: %s' % self.text
class TestGenericRelations(TestCase): # class TestGenericRelations(TestCase):
def setUp(self): # def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') # self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django') # Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python') # Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk') # self.note = Note.objects.create(text='Remember the milk')
Tag.objects.create(tagged_item=self.note, tag='reminder') # Tag.objects.create(tagged_item=self.note, tag='reminder')
def test_generic_relation(self): # def test_generic_relation(self):
""" # """
Test a relationship that spans a GenericRelation field. # Test a relationship that spans a GenericRelation field.
IE. A reverse generic relationship. # IE. A reverse generic relationship.
""" # """
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.RelatedField(many=True) # tags = serializers.RelatedField(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
serializer = BookmarkSerializer(self.bookmark) # serializer = BookmarkSerializer(self.bookmark)
expected = { # expected = {
'tags': ['django', 'python'], # 'tags': ['django', 'python'],
'url': 'https://www.djangoproject.com/' # 'url': 'https://www.djangoproject.com/'
} # }
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_generic_nested_relation(self): # def test_generic_nested_relation(self):
""" # """
Test saving a GenericRelation field via a nested serializer. # Test saving a GenericRelation field via a nested serializer.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = TagSerializer(many=True) # tags = TagSerializer(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
data = { # data = {
'url': 'https://docs.djangoproject.com/', # 'url': 'https://docs.djangoproject.com/',
'tags': [ # 'tags': [
{'tag': 'contenttypes'}, # {'tag': 'contenttypes'},
{'tag': 'genericrelations'}, # {'tag': 'genericrelations'},
] # ]
} # }
serializer = BookmarkSerializer(data=data) # serializer = BookmarkSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.object.tags.count(), 2) # self.assertEqual(serializer.object.tags.count(), 2)
def test_generic_fk(self): # def test_generic_fk(self):
""" # """
Test a relationship that spans a GenericForeignKey field. # Test a relationship that spans a GenericForeignKey field.
IE. A forward generic relationship. # IE. A forward generic relationship.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
tagged_item = serializers.RelatedField() # tagged_item = serializers.RelatedField()
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('id', 'content_type', 'object_id') # exclude = ('id', 'content_type', 'object_id')
serializer = TagSerializer(Tag.objects.all(), many=True) # serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [ # expected = [
{ # {
'tag': 'django', # 'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'python', # 'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'reminder', # 'tag': 'reminder',
'tagged_item': 'Note: Remember the milk' # 'tagged_item': 'Note: Remember the milk'
} # }
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_restore_object_generic_fk(self): # def test_restore_object_generic_fk(self):
""" # """
Ensure an object with a generic foreign key can be restored. # Ensure an object with a generic foreign key can be restored.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
serializer = TagSerializer() # serializer = TagSerializer()
bookmark = Bookmark(url='http://example.com') # bookmark = Bookmark(url='http://example.com')
attrs = {'tagged_item': bookmark, 'tag': 'example'} # attrs = {'tagged_item': bookmark, 'tag': 'example'}
tag = serializer.restore_object(attrs) # tag = serializer.restore_object(attrs)
self.assertEqual(tag.tagged_item, bookmark) # self.assertEqual(tag.tagged_item, bookmark)
from __future__ import unicode_literals from __future__ import unicode_literals
import django
from django.db import models from django.db import models
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.test import TestCase from django.test import TestCase
...@@ -11,44 +12,53 @@ from tests.models import ForeignKeySource, ForeignKeyTarget ...@@ -11,44 +12,53 @@ from tests.models import ForeignKeySource, ForeignKeyTarget
factory = APIRequestFactory() factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class ForeignKeySerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
""" """
Example description for OPTIONS. Example description for OPTIONS.
""" """
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
class InstanceView(generics.RetrieveUpdateDestroyAPIView): class InstanceView(generics.RetrieveUpdateDestroyAPIView):
""" """
Example description for OPTIONS. Example description for OPTIONS.
""" """
model = BasicModel queryset = BasicModel.objects.exclude(text='filtered out')
serializer_class = BasicSerializer
def get_queryset(self):
queryset = super(InstanceView, self).get_queryset()
return queryset.exclude(text='filtered out')
class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
""" """
FK: example description for OPTIONS. FK: example description for OPTIONS.
""" """
model = ForeignKeySource queryset = ForeignKeySource.objects.all()
serializer_class = ForeignKeySerializer
class SlugSerializer(serializers.ModelSerializer): class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field() # read only slug = serializers.Field(read_only=True)
class Meta: class Meta:
model = SlugBasedModel model = SlugBasedModel
exclude = ('id',) fields = ('text', 'slug')
class SlugBasedInstanceView(InstanceView): class SlugBasedInstanceView(InstanceView):
""" """
A model with a slug-field. A model with a slug-field.
""" """
model = SlugBasedModel queryset = SlugBasedModel.objects.all()
serializer_class = SlugSerializer serializer_class = SlugSerializer
lookup_field = 'slug' lookup_field = 'slug'
...@@ -112,46 +122,46 @@ class TestRootView(TestCase): ...@@ -112,46 +122,46 @@ class TestRootView(TestCase):
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self): # def test_options_root_view(self):
""" # """
OPTIONS requests to ListCreateAPIView should return metadata # OPTIONS requests to ListCreateAPIView should return metadata
""" # """
request = factory.options('/') # request = factory.options('/')
with self.assertNumQueries(0): # with self.assertNumQueries(0):
response = self.view(request).render() # response = self.view(request).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Root', # 'name': 'Root',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'POST': { # 'POST': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
"label": "Text comes here", # "label": "Text comes here",
"help_text": "Text description." # "help_text": "Text description."
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_post_cannot_set_id(self): def test_post_cannot_set_id(self):
""" """
...@@ -167,6 +177,9 @@ class TestRootView(TestCase): ...@@ -167,6 +177,9 @@ class TestRootView(TestCase):
self.assertEqual(created.text, 'foobar') self.assertEqual(created.text, 'foobar')
EXPECTED_QUERYS_FOR_PUT = 3 if django.VERSION < (1, 6) else 2
class TestInstanceView(TestCase): class TestInstanceView(TestCase):
def setUp(self): def setUp(self):
""" """
...@@ -210,10 +223,10 @@ class TestInstanceView(TestCase): ...@@ -210,10 +223,10 @@ class TestInstanceView(TestCase):
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(EXPECTED_QUERYS_FOR_PUT):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(dict(response.data), {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar') self.assertEqual(updated.text, 'foobar')
...@@ -224,7 +237,7 @@ class TestInstanceView(TestCase): ...@@ -224,7 +237,7 @@ class TestInstanceView(TestCase):
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json') request = factory.patch('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(EXPECTED_QUERYS_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -243,88 +256,88 @@ class TestInstanceView(TestCase): ...@@ -243,88 +256,88 @@ class TestInstanceView(TestCase):
ids = [obj.id for obj in self.objects.all()] ids = [obj.id for obj in self.objects.all()]
self.assertEqual(ids, [2, 3]) self.assertEqual(ids, [2, 3])
def test_options_instance_view(self): # def test_options_instance_view(self):
""" # """
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata # OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
""" # """
request = factory.options('/1') # request = factory.options('/1')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=1).render() # response = self.view(request, pk=1).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Instance', # 'name': 'Instance',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
'label': 'Text comes here', # 'label': 'Text comes here',
'help_text': 'Text description.' # 'help_text': 'Text description.'
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_options_before_instance_create(self): # def test_options_before_instance_create(self):
""" # """
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata # OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
before the instance has been created # before the instance has been created
""" # """
request = factory.options('/999') # request = factory.options('/999')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=999).render() # response = self.view(request, pk=999).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Instance', # 'name': 'Instance',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
'label': 'Text comes here', # 'label': 'Text comes here',
'help_text': 'Text description.' # 'help_text': 'Text description.'
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_get_instance_view_incorrect_arg(self): def test_get_instance_view_incorrect_arg(self):
""" """
...@@ -342,7 +355,7 @@ class TestInstanceView(TestCase): ...@@ -342,7 +355,7 @@ class TestInstanceView(TestCase):
""" """
data = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(EXPECTED_QUERYS_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -351,18 +364,15 @@ class TestInstanceView(TestCase): ...@@ -351,18 +364,15 @@ class TestInstanceView(TestCase):
def test_put_to_deleted_instance(self): def test_put_to_deleted_instance(self):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should create an object PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
if it does not currently exist. an object does not currently exist.
""" """
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(3): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar')
def test_put_to_filtered_out_instance(self): def test_put_to_filtered_out_instance(self):
""" """
...@@ -373,35 +383,7 @@ class TestInstanceView(TestCase): ...@@ -373,35 +383,7 @@ class TestInstanceView(TestCase):
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_put_as_create_on_id_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if it doesn't exist.
"""
data = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', data, format='json')
with self.assertNumQueries(3):
response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
self.assertEqual(new_obj.text, 'foobar')
def test_put_as_create_on_slug_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
"""
data = {'text': 'foobar'}
request = factory.put('/test_slug', data, format='json')
with self.assertNumQueries(2):
response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
def test_patch_cannot_create_an_object(self): def test_patch_cannot_create_an_object(self):
""" """
...@@ -433,52 +415,52 @@ class TestFKInstanceView(TestCase): ...@@ -433,52 +415,52 @@ class TestFKInstanceView(TestCase):
] ]
self.view = FKInstanceView.as_view() self.view = FKInstanceView.as_view()
def test_options_root_view(self): # def test_options_root_view(self):
""" # """
OPTIONS requests to ListCreateAPIView should return metadata # OPTIONS requests to ListCreateAPIView should return metadata
""" # """
request = factory.options('/999') # request = factory.options('/999')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=999).render() # response = self.view(request, pk=999).render()
expected = { # expected = {
'name': 'Fk Instance', # 'name': 'Fk Instance',
'description': 'FK: example description for OPTIONS.', # 'description': 'FK: example description for OPTIONS.',
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'id': { # 'id': {
'type': 'integer', # 'type': 'integer',
'required': False, # 'required': False,
'read_only': True, # 'read_only': True,
'label': 'ID' # 'label': 'ID'
}, # },
'name': { # 'name': {
'type': 'string', # 'type': 'string',
'required': True, # 'required': True,
'read_only': False, # 'read_only': False,
'label': 'name', # 'label': 'name',
'max_length': 100 # 'max_length': 100
}, # },
'target': { # 'target': {
'type': 'field', # 'type': 'field',
'required': True, # 'required': True,
'read_only': False, # 'read_only': False,
'label': 'Target', # 'label': 'Target',
'help_text': 'Target' # 'help_text': 'Target'
} # }
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
class TestOverriddenGetObject(TestCase): class TestOverriddenGetObject(TestCase):
...@@ -503,7 +485,7 @@ class TestOverriddenGetObject(TestCase): ...@@ -503,7 +485,7 @@ class TestOverriddenGetObject(TestCase):
""" """
Example detail view for override of get_object(). Example detail view for override of get_object().
""" """
model = BasicModel serializer_class = BasicSerializer
def get_object(self): def get_object(self):
pk = int(self.kwargs['pk']) pk = int(self.kwargs['pk'])
...@@ -565,7 +547,9 @@ class ClassA(models.Model): ...@@ -565,7 +547,9 @@ class ClassA(models.Model):
class ClassASerializer(serializers.ModelSerializer): class ClassASerializer(serializers.ModelSerializer):
childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') childs = serializers.PrimaryKeyRelatedField(
many=True, queryset=ClassB.objects.all()
)
class Meta: class Meta:
model = ClassA model = ClassA
...@@ -573,7 +557,7 @@ class ClassASerializer(serializers.ModelSerializer): ...@@ -573,7 +557,7 @@ class ClassASerializer(serializers.ModelSerializer):
class ExampleView(generics.ListCreateAPIView): class ExampleView(generics.ListCreateAPIView):
serializer_class = ClassASerializer serializer_class = ClassASerializer
model = ClassA queryset = ClassA.objects.all()
class TestM2MBrowseableAPI(TestCase): class TestM2MBrowseableAPI(TestCase):
...@@ -603,7 +587,7 @@ class TwoFieldModel(models.Model): ...@@ -603,7 +587,7 @@ class TwoFieldModel(models.Model):
class DynamicSerializerView(generics.ListCreateAPIView): class DynamicSerializerView(generics.ListCreateAPIView):
model = TwoFieldModel queryset = TwoFieldModel.objects.all()
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get_serializer_class(self): def get_serializer_class(self):
...@@ -612,8 +596,11 @@ class DynamicSerializerView(generics.ListCreateAPIView): ...@@ -612,8 +596,11 @@ class DynamicSerializerView(generics.ListCreateAPIView):
class Meta: class Meta:
model = TwoFieldModel model = TwoFieldModel
fields = ('field_b',) fields = ('field_b',)
else:
class DynamicSerializer(serializers.ModelSerializer):
class Meta:
model = TwoFieldModel
return DynamicSerializer return DynamicSerializer
return super(DynamicSerializerView, self).get_serializer_class()
class TestFilterBackendAppliedToViews(TestCase): class TestFilterBackendAppliedToViews(TestCase):
......
from __future__ import unicode_literals # from __future__ import unicode_literals
import json # import json
from django.test import TestCase # from django.test import TestCase
from rest_framework import generics, status, serializers # from rest_framework import generics, status, serializers
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from rest_framework.settings import api_settings # from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory # from rest_framework.test import APIRequestFactory
from tests.models import ( # from tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, # Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
Album, Photo, OptionalRelationModel # Album, Photo, OptionalRelationModel
) # )
factory = APIRequestFactory() # factory = APIRequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer): # class BlogPostCommentSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') # url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
text = serializers.CharField() # text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') # blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
class Meta: # class Meta:
model = BlogPostComment # model = BlogPostComment
fields = ('text', 'blog_post_url', 'url') # fields = ('text', 'blog_post_url', 'url')
class PhotoSerializer(serializers.Serializer): # class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() # description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title') # album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title')
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
return Photo(**attrs) # return Photo(**attrs)
class AlbumSerializer(serializers.ModelSerializer): # class AlbumSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') # url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
class Meta: # class Meta:
model = Album # model = Album
fields = ('title', 'url') # fields = ('title', 'url')
class BasicList(generics.ListCreateAPIView): # class BasicSerializer(serializers.HyperlinkedModelSerializer):
model = BasicModel # class Meta:
model_serializer_class = serializers.HyperlinkedModelSerializer # model = BasicModel
class BasicDetail(generics.RetrieveUpdateDestroyAPIView): # class AnchorSerializer(serializers.HyperlinkedModelSerializer):
model = BasicModel # class Meta:
model_serializer_class = serializers.HyperlinkedModelSerializer # model = Anchor
class AnchorDetail(generics.RetrieveAPIView): # class ManyToManySerializer(serializers.HyperlinkedModelSerializer):
model = Anchor # class Meta:
model_serializer_class = serializers.HyperlinkedModelSerializer # model = ManyToManyModel
class ManyToManyList(generics.ListAPIView): # class BlogPostSerializer(serializers.ModelSerializer):
model = ManyToManyModel # class Meta:
model_serializer_class = serializers.HyperlinkedModelSerializer # model = BlogPost
class ManyToManyDetail(generics.RetrieveAPIView): # class OptionalRelationSerializer(serializers.HyperlinkedModelSerializer):
model = ManyToManyModel # class Meta:
model_serializer_class = serializers.HyperlinkedModelSerializer # model = OptionalRelationModel
class BlogPostCommentListCreate(generics.ListCreateAPIView): # class BasicList(generics.ListCreateAPIView):
model = BlogPostComment # queryset = BasicModel.objects.all()
serializer_class = BlogPostCommentSerializer # serializer_class = BasicSerializer
class BlogPostCommentDetail(generics.RetrieveAPIView): # class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model = BlogPostComment # queryset = BasicModel.objects.all()
serializer_class = BlogPostCommentSerializer # serializer_class = BasicSerializer
class BlogPostDetail(generics.RetrieveAPIView): # class AnchorDetail(generics.RetrieveAPIView):
model = BlogPost # queryset = Anchor.objects.all()
# serializer_class = AnchorSerializer
class PhotoListCreate(generics.ListCreateAPIView): # class ManyToManyList(generics.ListAPIView):
model = Photo # queryset = ManyToManyModel.objects.all()
model_serializer_class = PhotoSerializer # serializer_class = ManyToManySerializer
class AlbumDetail(generics.RetrieveAPIView): # class ManyToManyDetail(generics.RetrieveAPIView):
model = Album # queryset = ManyToManyModel.objects.all()
serializer_class = AlbumSerializer # serializer_class = ManyToManySerializer
lookup_field = 'title'
# class BlogPostCommentListCreate(generics.ListCreateAPIView):
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): # queryset = BlogPostComment.objects.all()
model = OptionalRelationModel # serializer_class = BlogPostCommentSerializer
model_serializer_class = serializers.HyperlinkedModelSerializer
# class BlogPostCommentDetail(generics.RetrieveAPIView):
urlpatterns = patterns( # queryset = BlogPostComment.objects.all()
'', # serializer_class = BlogPostCommentSerializer
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), # class BlogPostDetail(generics.RetrieveAPIView):
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), # queryset = BlogPost.objects.all()
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), # serializer_class = BlogPostSerializer
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), # class PhotoListCreate(generics.ListCreateAPIView):
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), # queryset = Photo.objects.all()
url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), # serializer_class = PhotoSerializer
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
)
# class AlbumDetail(generics.RetrieveAPIView):
# queryset = Album.objects.all()
class TestBasicHyperlinkedView(TestCase): # serializer_class = AlbumSerializer
urls = 'tests.test_hyperlinkedserializers' # lookup_field = 'title'
def setUp(self):
""" # class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
Create 3 BasicModel instances. # queryset = OptionalRelationModel.objects.all()
""" # serializer_class = OptionalRelationSerializer
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save() # urlpatterns = patterns(
self.objects = BasicModel.objects # '',
self.data = [ # url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} # url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
for obj in self.objects.all() # url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
] # url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
self.list_view = BasicList.as_view() # url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
self.detail_view = BasicDetail.as_view() # url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
# url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
def test_get_list_view(self): # url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
""" # url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
GET requests to ListCreateAPIView should return list of objects. # url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
""" # url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
request = factory.get('/basic/') # )
response = self.list_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data) # class TestBasicHyperlinkedView(TestCase):
# urls = 'tests.test_hyperlinkedserializers'
def test_get_detail_view(self):
""" # def setUp(self):
GET requests to ListCreateAPIView should return list of objects. # """
""" # Create 3 BasicModel instances.
request = factory.get('/basic/1') # """
response = self.detail_view(request, pk=1).render() # items = ['foo', 'bar', 'baz']
self.assertEqual(response.status_code, status.HTTP_200_OK) # for item in items:
self.assertEqual(response.data, self.data[0]) # BasicModel(text=item).save()
# self.objects = BasicModel.objects
# self.data = [
class TestManyToManyHyperlinkedView(TestCase): # {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
urls = 'tests.test_hyperlinkedserializers' # for obj in self.objects.all()
# ]
def setUp(self): # self.list_view = BasicList.as_view()
""" # self.detail_view = BasicDetail.as_view()
Create 3 BasicModel instances.
""" # def test_get_list_view(self):
items = ['foo', 'bar', 'baz'] # """
anchors = [] # GET requests to ListCreateAPIView should return list of objects.
for item in items: # """
anchor = Anchor(text=item) # request = factory.get('/basic/')
anchor.save() # response = self.list_view(request).render()
anchors.append(anchor) # self.assertEqual(response.status_code, status.HTTP_200_OK)
# self.assertEqual(response.data, self.data)
manytomany = ManyToManyModel()
manytomany.save() # def test_get_detail_view(self):
manytomany.rel.add(*anchors) # """
# GET requests to ListCreateAPIView should return list of objects.
self.data = [{ # """
'url': 'http://testserver/manytomany/1/', # request = factory.get('/basic/1')
'rel': [ # response = self.detail_view(request, pk=1).render()
'http://testserver/anchor/1/', # self.assertEqual(response.status_code, status.HTTP_200_OK)
'http://testserver/anchor/2/', # self.assertEqual(response.data, self.data[0])
'http://testserver/anchor/3/',
]
}] # class TestManyToManyHyperlinkedView(TestCase):
self.list_view = ManyToManyList.as_view() # urls = 'tests.test_hyperlinkedserializers'
self.detail_view = ManyToManyDetail.as_view()
# def setUp(self):
def test_get_list_view(self): # """
""" # Create 3 BasicModel instances.
GET requests to ListCreateAPIView should return list of objects. # """
""" # items = ['foo', 'bar', 'baz']
request = factory.get('/manytomany/') # anchors = []
response = self.list_view(request) # for item in items:
self.assertEqual(response.status_code, status.HTTP_200_OK) # anchor = Anchor(text=item)
self.assertEqual(response.data, self.data) # anchor.save()
# anchors.append(anchor)
def test_get_detail_view(self):
""" # manytomany = ManyToManyModel()
GET requests to ListCreateAPIView should return list of objects. # manytomany.save()
""" # manytomany.rel.add(*anchors)
request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1) # self.data = [{
self.assertEqual(response.status_code, status.HTTP_200_OK) # 'url': 'http://testserver/manytomany/1/',
self.assertEqual(response.data, self.data[0]) # 'rel': [
# 'http://testserver/anchor/1/',
# 'http://testserver/anchor/2/',
class TestHyperlinkedIdentityFieldLookup(TestCase): # 'http://testserver/anchor/3/',
urls = 'tests.test_hyperlinkedserializers' # ]
# }]
def setUp(self): # self.list_view = ManyToManyList.as_view()
""" # self.detail_view = ManyToManyDetail.as_view()
Create 3 Album instances.
""" # def test_get_list_view(self):
titles = ['foo', 'bar', 'baz'] # """
for title in titles: # GET requests to ListCreateAPIView should return list of objects.
album = Album(title=title) # """
album.save() # request = factory.get('/manytomany/')
self.detail_view = AlbumDetail.as_view() # response = self.list_view(request)
self.data = { # self.assertEqual(response.status_code, status.HTTP_200_OK)
'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, # self.assertEqual(response.data, self.data)
'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} # def test_get_detail_view(self):
} # """
# GET requests to ListCreateAPIView should return list of objects.
def test_lookup_field(self): # """
""" # request = factory.get('/manytomany/1/')
GET requests to AlbumDetail view should return serialized Albums # response = self.detail_view(request, pk=1)
with a url field keyed by `title`. # self.assertEqual(response.status_code, status.HTTP_200_OK)
""" # self.assertEqual(response.data, self.data[0])
for album in Album.objects.all():
request = factory.get('/albums/{0}/'.format(album.title))
response = self.detail_view(request, title=album.title) # class TestHyperlinkedIdentityFieldLookup(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) # urls = 'tests.test_hyperlinkedserializers'
self.assertEqual(response.data, self.data[album.title])
# def setUp(self):
# """
class TestCreateWithForeignKeys(TestCase): # Create 3 Album instances.
urls = 'tests.test_hyperlinkedserializers' # """
# titles = ['foo', 'bar', 'baz']
def setUp(self): # for title in titles:
""" # album = Album(title=title)
Create a blog post # album.save()
""" # self.detail_view = AlbumDetail.as_view()
self.post = BlogPost.objects.create(title="Test post") # self.data = {
self.create_view = BlogPostCommentListCreate.as_view() # 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
# 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
def test_create_comment(self): # 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
# }
data = {
'text': 'A test comment', # def test_lookup_field(self):
'blog_post_url': 'http://testserver/posts/1/' # """
} # GET requests to AlbumDetail view should return serialized Albums
# with a url field keyed by `title`.
request = factory.post('/comments/', data=data) # """
response = self.create_view(request) # for album in Album.objects.all():
self.assertEqual(response.status_code, status.HTTP_201_CREATED) # request = factory.get('/albums/{0}/'.format(album.title))
self.assertEqual(response['Location'], 'http://testserver/comments/1/') # response = self.detail_view(request, title=album.title)
self.assertEqual(self.post.blogpostcomment_set.count(), 1) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') # self.assertEqual(response.data, self.data[album.title])
class TestCreateWithForeignKeysAndCustomSlug(TestCase): # class TestCreateWithForeignKeys(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create an Album # Create a blog post
""" # """
self.post = Album.objects.create(title='test-album') # self.post = BlogPost.objects.create(title="Test post")
self.list_create_view = PhotoListCreate.as_view() # self.create_view = BlogPostCommentListCreate.as_view()
def test_create_photo(self): # def test_create_comment(self):
data = { # data = {
'description': 'A test photo', # 'text': 'A test comment',
'album_url': 'http://testserver/albums/test-album/' # 'blog_post_url': 'http://testserver/posts/1/'
} # }
request = factory.post('/photos/', data=data) # request = factory.post('/comments/', data=data)
response = self.list_create_view(request) # response = self.create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) # self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') # self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.photo_set.count(), 1) # self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') # self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
class TestOptionalRelationHyperlinkedView(TestCase): # class TestCreateWithForeignKeysAndCustomSlug(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create 1 OptionalRelationModel instances. # Create an Album
""" # """
OptionalRelationModel().save() # self.post = Album.objects.create(title='test-album')
self.objects = OptionalRelationModel.objects # self.list_create_view = PhotoListCreate.as_view()
self.detail_view = OptionalRelationDetail.as_view()
self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} # def test_create_photo(self):
def test_get_detail_view(self): # data = {
""" # 'description': 'A test photo',
GET requests to RetrieveAPIView with optional relations should return None # 'album_url': 'http://testserver/albums/test-album/'
for non existing relations. # }
"""
request = factory.get('/optionalrelationmodel-detail/1') # request = factory.post('/photos/', data=data)
response = self.detail_view(request, pk=1) # response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, self.data) # self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
# self.assertEqual(self.post.photo_set.count(), 1)
def test_put_detail_view(self): # self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
"""
PUT requests to RetrieveUpdateDestroyAPIView with optional relations
should accept None for non existing relations. # class TestOptionalRelationHyperlinkedView(TestCase):
""" # urls = 'tests.test_hyperlinkedserializers'
response = self.client.put('/optionalrelation/1/',
data=json.dumps(self.data), # def setUp(self):
content_type='application/json') # """
self.assertEqual(response.status_code, status.HTTP_200_OK) # Create 1 OptionalRelationModel instances.
# """
# OptionalRelationModel().save()
class TestOverriddenURLField(TestCase): # self.objects = OptionalRelationModel.objects
def setUp(self): # self.detail_view = OptionalRelationDetail.as_view()
class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): # self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
url = serializers.SerializerMethodField('get_url')
# def test_get_detail_view(self):
class Meta: # """
model = BlogPost # GET requests to RetrieveAPIView with optional relations should return None
fields = ('title', 'url') # for non existing relations.
# """
def get_url(self, obj): # request = factory.get('/optionalrelationmodel-detail/1')
return 'foo bar' # response = self.detail_view(request, pk=1)
# self.assertEqual(response.status_code, status.HTTP_200_OK)
self.Serializer = OverriddenURLSerializer # self.assertEqual(response.data, self.data)
self.obj = BlogPost.objects.create(title='New blog post')
# def test_put_detail_view(self):
def test_overridden_url_field(self): # """
""" # PUT requests to RetrieveUpdateDestroyAPIView with optional relations
The 'url' field should respect overriding. # should accept None for non existing relations.
Regression test for #936. # """
""" # response = self.client.put('/optionalrelation/1/',
serializer = self.Serializer(self.obj) # data=json.dumps(self.data),
self.assertEqual( # content_type='application/json')
serializer.data, # self.assertEqual(response.status_code, status.HTTP_200_OK)
{'title': 'New blog post', 'url': 'foo bar'}
)
# class TestOverriddenURLField(TestCase):
# def setUp(self):
class TestURLFieldNameBySettings(TestCase): # class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
urls = 'tests.test_hyperlinkedserializers' # url = serializers.SerializerMethodField('get_url')
def setUp(self): # class Meta:
self.saved_url_field_name = api_settings.URL_FIELD_NAME # model = BlogPost
api_settings.URL_FIELD_NAME = 'global_url_field' # fields = ('title', 'url')
class Serializer(serializers.HyperlinkedModelSerializer): # def get_url(self, obj):
# return 'foo bar'
class Meta:
model = BlogPost # self.Serializer = OverriddenURLSerializer
fields = ('title', api_settings.URL_FIELD_NAME) # self.obj = BlogPost.objects.create(title='New blog post')
self.Serializer = Serializer # def test_overridden_url_field(self):
self.obj = BlogPost.objects.create(title="New blog post") # """
# The 'url' field should respect overriding.
def tearDown(self): # Regression test for #936.
api_settings.URL_FIELD_NAME = self.saved_url_field_name # """
# serializer = self.Serializer(self.obj)
def test_overridden_url_field_name(self): # self.assertEqual(
request = factory.get('/posts/') # serializer.data,
serializer = self.Serializer(self.obj, context={'request': request}) # {'title': 'New blog post', 'url': 'foo bar'}
self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) # )
class TestURLFieldNameByOptions(TestCase): # class TestURLFieldNameBySettings(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
class Serializer(serializers.HyperlinkedModelSerializer): # self.saved_url_field_name = api_settings.URL_FIELD_NAME
# api_settings.URL_FIELD_NAME = 'global_url_field'
class Meta:
model = BlogPost # class Serializer(serializers.HyperlinkedModelSerializer):
fields = ('title', 'serializer_url_field')
url_field_name = 'serializer_url_field' # class Meta:
# model = BlogPost
self.Serializer = Serializer # fields = ('title', api_settings.URL_FIELD_NAME)
self.obj = BlogPost.objects.create(title="New blog post")
# self.Serializer = Serializer
def test_overridden_url_field_name(self): # self.obj = BlogPost.objects.create(title="New blog post")
request = factory.get('/posts/')
serializer = self.Serializer(self.obj, context={'request': request}) # def tearDown(self):
self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) # api_settings.URL_FIELD_NAME = self.saved_url_field_name
# def test_overridden_url_field_name(self):
# request = factory.get('/posts/')
# serializer = self.Serializer(self.obj, context={'request': request})
# self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
# class TestURLFieldNameByOptions(TestCase):
# urls = 'tests.test_hyperlinkedserializers'
# def setUp(self):
# class Serializer(serializers.HyperlinkedModelSerializer):
# class Meta:
# model = BlogPost
# fields = ('title', 'serializer_url_field')
# url_field_name = 'serializer_url_field'
# self.Serializer = Serializer
# self.obj = BlogPost.objects.create(title="New blog post")
# def test_overridden_url_field_name(self):
# request = factory.get('/posts/')
# serializer = self.Serializer(self.obj, context={'request': request})
# self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
"""
The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
shortcuts for automatically creating serializers based on a given model class.
These tests deal with ensuring that we correctly map the model fields onto
an appropriate set of serializer fields for each case.
"""
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.test import TestCase
from rest_framework import serializers
def dedent(blocktext):
return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
# Testing regular field mappings
class CustomField(models.Field):
"""
A custom model field simply for testing purposes.
"""
pass
class RegularFieldsModel(models.Model):
"""
A model class for testing regular flat fields.
"""
auto_field = models.AutoField(primary_key=True)
big_integer_field = models.BigIntegerField()
boolean_field = models.BooleanField(default=False)
char_field = models.CharField(max_length=100)
comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
date_field = models.DateField()
datetime_field = models.DateTimeField()
decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
email_field = models.EmailField(max_length=100)
float_field = models.FloatField()
integer_field = models.IntegerField()
null_boolean_field = models.NullBooleanField()
positive_integer_field = models.PositiveIntegerField()
positive_small_integer_field = models.PositiveSmallIntegerField()
slug_field = models.SlugField(max_length=100)
small_integer_field = models.SmallIntegerField()
text_field = models.TextField()
time_field = models.TimeField()
url_field = models.URLField(max_length=100)
custom_field = CustomField()
def method(self):
return 'method'
class TestRegularFieldMappings(TestCase):
def test_regular_fields(self):
"""
Model fields should map to their equivelent serializer fields.
"""
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=True)
big_integer_field = IntegerField()
boolean_field = BooleanField(default=False)
char_field = CharField(max_length=100)
comma_seperated_integer_field = CharField(max_length=100, validators=[<django.core.validators.RegexValidator object>])
date_field = DateField()
datetime_field = DateTimeField()
decimal_field = DecimalField(decimal_places=1, max_digits=3)
email_field = EmailField(max_length=100)
float_field = FloatField()
integer_field = IntegerField()
null_boolean_field = BooleanField(required=False)
positive_integer_field = IntegerField()
positive_small_integer_field = IntegerField()
slug_field = SlugField(max_length=100)
small_integer_field = IntegerField()
text_field = CharField()
time_field = TimeField()
url_field = URLField(max_length=100)
custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_method_field(self):
"""
Properties and methods on the model should be allowed as `Meta.fields`
values, and should map to `ReadOnlyField`.
"""
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
fields = ('auto_field', 'method')
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=True)
method = ReadOnlyField()
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_fields(self):
"""
Both `pk` and the actual primary key name are valid in `Meta.fields`.
"""
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
fields = ('pk', 'auto_field')
expected = dedent("""
TestSerializer():
pk = IntegerField(label='Auto field', read_only=True)
auto_field = IntegerField(read_only=True)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_extra_field_kwargs(self):
"""
Ensure `extra_kwargs` are passed to generated fields.
"""
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
fields = ('auto_field', 'char_field')
extra_kwargs = {'char_field': {'default': 'extra'}}
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=True)
char_field = CharField(default='extra', max_length=100)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_invalid_field(self):
"""
Field names that do not map to a model field or relationship should
raise a configuration errror.
"""
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
fields = ('auto_field', 'invalid')
with self.assertRaises(ImproperlyConfigured) as excinfo:
TestSerializer()
expected = 'Field name `invalid` is not valid for model `ModelBase`.'
assert str(excinfo.exception) == expected
def test_missing_field(self):
"""
Fields that have been declared on the serializer class must be included
in the `Meta.fields` if it exists.
"""
class TestSerializer(serializers.ModelSerializer):
missing = serializers.ReadOnlyField()
class Meta:
model = RegularFieldsModel
fields = ('auto_field',)
with self.assertRaises(ImproperlyConfigured) as excinfo:
TestSerializer()
expected = (
'Field `missing` has been declared on serializer '
'`TestSerializer`, but is missing from `Meta.fields`.'
)
assert str(excinfo.exception) == expected
# Testing relational field mappings
class ForeignKeyTargetModel(models.Model):
name = models.CharField(max_length=100)
class ManyToManyTargetModel(models.Model):
name = models.CharField(max_length=100)
class OneToOneTargetModel(models.Model):
name = models.CharField(max_length=100)
class ThroughTargetModel(models.Model):
name = models.CharField(max_length=100)
class Supplementary(models.Model):
extra = models.IntegerField()
forwards = models.ForeignKey('ThroughTargetModel')
backwards = models.ForeignKey('RelationalModel')
class RelationalModel(models.Model):
foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='reverse_foreign_key')
many_to_many = models.ManyToManyField(ManyToManyTargetModel, related_name='reverse_many_to_many')
one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one')
through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
class TestRelationalFieldMappings(TestCase):
def test_pk_relations(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all())
one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all())
many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_nested_relations(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
depth = 1
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
expected = dedent("""
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail')
one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), view_name='onetoonetargetmodel-detail')
many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_nested_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
depth = 1
expected = dedent("""
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
name = CharField(max_length=100)
one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100)
through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_reverse_foreign_key(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeyTargetModel
fields = ('id', 'name', 'reverse_foreign_key')
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneTargetModel
fields = ('id', 'name', 'reverse_one_to_one')
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManyTargetModel
fields = ('id', 'name', 'reverse_many_to_many')
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ThroughTargetModel
fields = ('id', 'name', 'reverse_through')
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
self.assertEqual(repr(TestSerializer()), expected)
class TestIntegration(TestCase):
def setUp(self):
self.foreign_key_target = ForeignKeyTargetModel.objects.create(
name='foreign_key'
)
self.one_to_one_target = OneToOneTargetModel.objects.create(
name='one_to_one'
)
self.many_to_many_targets = [
ManyToManyTargetModel.objects.create(
name='many_to_many (%d)' % idx
) for idx in range(3)
]
self.instance = RelationalModel.objects.create(
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
)
self.instance.many_to_many = self.many_to_many_targets
self.instance.save()
def test_pk_retrival(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
serializer = TestSerializer(self.instance)
expected = {
'id': self.instance.pk,
'foreign_key': self.foreign_key_target.pk,
'one_to_one': self.one_to_one_target.pk,
'many_to_many': [item.pk for item in self.many_to_many_targets],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_create(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
new_foreign_key = ForeignKeyTargetModel.objects.create(
name='foreign_key'
)
new_one_to_one = OneToOneTargetModel.objects.create(
name='one_to_one'
)
new_many_to_many = [
ManyToManyTargetModel.objects.create(
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer = TestSerializer(data=data)
assert serializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
assert [
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
'id': instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_update(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
new_foreign_key = ForeignKeyTargetModel.objects.create(
name='foreign_key'
)
new_one_to_one = OneToOneTargetModel.objects.create(
name='one_to_one'
)
new_many_to_many = [
ManyToManyTargetModel.objects.create(
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer = TestSerializer(self.instance, data=data)
assert serializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
assert [
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
'id': self.instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework.serializers import _resolve_model from rest_framework.utils.model_meta import _resolve_model
from tests.models import BasicModel from tests.models import BasicModel
......
from django.core.urlresolvers import reverse # from django.core.urlresolvers import reverse
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from rest_framework.test import APITestCase # from rest_framework import serializers, generics
from tests.models import NullableForeignKeySource # from rest_framework.test import APITestCase
from tests.serializers import NullableFKSourceSerializer # from tests.models import NullableForeignKeySource
from tests.views import NullableFKSourceDetail
urlpatterns = patterns( # class NullableFKSourceSerializer(serializers.ModelSerializer):
'', # class Meta:
url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), # model = NullableForeignKeySource
)
class NullableForeignKeyTests(APITestCase): # class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
""" # queryset = NullableForeignKeySource.objects.all()
DRF should be able to handle nullable foreign keys when a test # serializer_class = NullableFKSourceSerializer
Client POST/PUT request is made with its own serialized object.
"""
urls = 'tests.test_nullable_fields'
def test_updating_object_with_null_fk(self):
obj = NullableForeignKeySource(name='example', target=None)
obj.save()
serialized_data = NullableFKSourceSerializer(obj).data
response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) # urlpatterns = patterns(
# '',
# url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
# )
self.assertEqual(response.data, serialized_data)
# class NullableForeignKeyTests(APITestCase):
# """
# DRF should be able to handle nullable foreign keys when a test
# Client POST/PUT request is made with its own serialized object.
# """
# urls = 'tests.test_nullable_fields'
# def test_updating_object_with_null_fk(self):
# obj = NullableForeignKeySource(name='example', target=None)
# obj.save()
# serialized_data = NullableFKSourceSerializer(obj).data
# response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
# self.assertEqual(response.data, serialized_data)
...@@ -4,7 +4,7 @@ from decimal import Decimal ...@@ -4,7 +4,7 @@ from decimal import Decimal
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers from rest_framework import generics, serializers, status, pagination, filters
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from .models import BasicModel, FilterableItem from .models import BasicModel, FilterableItem
...@@ -22,11 +22,22 @@ def split_arguments_from_url(url): ...@@ -22,11 +22,22 @@ def split_arguments_from_url(url):
return path, args return path, args
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
model = FilterableItem
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
""" """
Example description for OPTIONS. Example description for OPTIONS.
""" """
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
paginate_by = 10 paginate_by = 10
...@@ -34,14 +45,16 @@ class DefaultPageSizeKwargView(generics.ListAPIView): ...@@ -34,14 +45,16 @@ class DefaultPageSizeKwargView(generics.ListAPIView):
""" """
View for testing default paginate_by_param usage View for testing default paginate_by_param usage
""" """
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
class PaginateByParamView(generics.ListAPIView): class PaginateByParamView(generics.ListAPIView):
""" """
View for testing custom paginate_by_param usage View for testing custom paginate_by_param usage
""" """
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
paginate_by_param = 'page_size' paginate_by_param = 'page_size'
...@@ -49,7 +62,8 @@ class MaxPaginateByView(generics.ListAPIView): ...@@ -49,7 +62,8 @@ class MaxPaginateByView(generics.ListAPIView):
""" """
View for testing custom max_paginate_by usage View for testing custom max_paginate_by usage
""" """
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
paginate_by = 3 paginate_by = 3
max_paginate_by = 5 max_paginate_by = 5
paginate_by_param = 'page_size' paginate_by_param = 'page_size'
...@@ -121,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -121,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
for obj in self.objects.all() for obj in self.objects.all()
] ]
...@@ -140,7 +154,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -140,7 +154,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
fields = ['text', 'decimal', 'date'] fields = ['text', 'decimal', 'date']
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
paginate_by = 10 paginate_by = 10
filter_class = DecimalFilter filter_class = DecimalFilter
filter_backends = (filters.DjangoFilterBackend,) filter_backends = (filters.DjangoFilterBackend,)
...@@ -188,7 +203,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -188,7 +203,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
class BasicFilterFieldsRootView(generics.ListCreateAPIView): class BasicFilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
paginate_by = 10 paginate_by = 10
filter_backends = (DecimalFilterBackend,) filter_backends = (DecimalFilterBackend,)
...@@ -365,7 +381,7 @@ class TestMaxPaginateByParam(TestCase): ...@@ -365,7 +381,7 @@ class TestMaxPaginateByParam(TestCase):
# Tests for context in pagination serializers # Tests for context in pagination serializers
class CustomField(serializers.Field): class CustomField(serializers.ReadOnlyField):
def to_native(self, value): def to_native(self, value):
if 'view' not in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into custom field") raise RuntimeError("context isn't getting passed into custom field")
...@@ -375,10 +391,10 @@ class CustomField(serializers.Field): ...@@ -375,10 +391,10 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer): class BasicModelSerializer(serializers.Serializer):
text = CustomField() text = CustomField()
def __init__(self, *args, **kwargs): def to_native(self, value):
super(BasicModelSerializer, self).__init__(*args, **kwargs)
if 'view' not in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into serializer init") raise RuntimeError("context isn't getting passed into serializer")
return super(BasicSerializer, self).to_native(value)
class TestContextPassedToCustomField(TestCase): class TestContextPassedToCustomField(TestCase):
...@@ -387,7 +403,7 @@ class TestContextPassedToCustomField(TestCase): ...@@ -387,7 +403,7 @@ class TestContextPassedToCustomField(TestCase):
def test_with_pagination(self): def test_with_pagination(self):
class ListView(generics.ListCreateAPIView): class ListView(generics.ListCreateAPIView):
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicModelSerializer serializer_class = BasicModelSerializer
paginate_by = 1 paginate_by = 1
...@@ -407,7 +423,7 @@ class LinksSerializer(serializers.Serializer): ...@@ -407,7 +423,7 @@ class LinksSerializer(serializers.Serializer):
class CustomPaginationSerializer(pagination.BasePaginationSerializer): class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source links = LinksSerializer(source='*') # Takes the page object as the source
total_results = serializers.Field(source='paginator.count') total_results = serializers.ReadOnlyField(source='paginator.count')
results_field = 'objects' results_field = 'objects'
......
# -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from rest_framework.compat import StringIO from rest_framework.compat import StringIO
from django import forms from django import forms
...@@ -113,3 +115,25 @@ class TestFileUploadParser(TestCase): ...@@ -113,3 +115,25 @@ class TestFileUploadParser(TestCase):
parser = FileUploadParser() parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
self.assertEqual(filename, 'file.txt') self.assertEqual(filename, 'file.txt')
def test_get_encoded_filename(self):
parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
self.assertEqual(filename, 'ÀĥƦ.txt')
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
self.assertEqual(filename, 'ÀĥƦ.txt')
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
self.assertEqual(filename, 'ÀĥƦ.txt')
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8--ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
self.assertEqual(filename, 'fallback.txt')
def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
...@@ -3,7 +3,7 @@ from django.contrib.auth.models import User, Permission, Group ...@@ -3,7 +3,7 @@ from django.contrib.auth.models import User, Permission, Group
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING
from rest_framework.compat import guardian, get_model_name from rest_framework.compat import guardian, get_model_name
from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
...@@ -13,14 +13,21 @@ import base64 ...@@ -13,14 +13,21 @@ import base64
factory = APIRequestFactory() factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView): class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel model = BasicModel
class RootView(generics.ListCreateAPIView):
queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions] permission_classes = [permissions.DjangoModelPermissions]
class InstanceView(generics.RetrieveUpdateDestroyAPIView): class InstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel queryset = BasicModel.objects.all()
serializer_class = BasicSerializer
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions] permission_classes = [permissions.DjangoModelPermissions]
...@@ -88,72 +95,59 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -88,72 +95,59 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_has_put_as_create_permissions(self): # def test_options_permitted(self):
# User only has update permissions - should be able to update an entity. # request = factory.options(
request = factory.put('/1', {'text': 'foobar'}, format='json', # '/',
HTTP_AUTHORIZATION=self.updateonly_credentials) # HTTP_AUTHORIZATION=self.permitted_credentials
response = instance_view(request, pk='1') # )
self.assertEqual(response.status_code, status.HTTP_200_OK) # response = root_view(request, pk='1')
# self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied. # self.assertIn('actions', response.data)
request = factory.put('/2', {'text': 'foobar'}, format='json', # self.assertEqual(list(response.data['actions'].keys()), ['POST'])
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2') # request = factory.options(
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) # '/1',
# HTTP_AUTHORIZATION=self.permitted_credentials
def test_options_permitted(self): # )
request = factory.options( # response = instance_view(request, pk='1')
'/', # self.assertEqual(response.status_code, status.HTTP_200_OK)
HTTP_AUTHORIZATION=self.permitted_credentials # self.assertIn('actions', response.data)
) # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # def test_options_disallowed(self):
self.assertIn('actions', response.data) # request = factory.options(
self.assertEqual(list(response.data['actions'].keys()), ['POST']) # '/',
# HTTP_AUTHORIZATION=self.disallowed_credentials
request = factory.options( # )
'/1', # response = root_view(request, pk='1')
HTTP_AUTHORIZATION=self.permitted_credentials # self.assertEqual(response.status_code, status.HTTP_200_OK)
) # self.assertNotIn('actions', response.data)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # request = factory.options(
self.assertIn('actions', response.data) # '/1',
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) # HTTP_AUTHORIZATION=self.disallowed_credentials
# )
def test_options_disallowed(self): # response = instance_view(request, pk='1')
request = factory.options( # self.assertEqual(response.status_code, status.HTTP_200_OK)
'/', # self.assertNotIn('actions', response.data)
HTTP_AUTHORIZATION=self.disallowed_credentials
) # def test_options_updateonly(self):
response = root_view(request, pk='1') # request = factory.options(
self.assertEqual(response.status_code, status.HTTP_200_OK) # '/',
self.assertNotIn('actions', response.data) # HTTP_AUTHORIZATION=self.updateonly_credentials
# )
request = factory.options( # response = root_view(request, pk='1')
'/1', # self.assertEqual(response.status_code, status.HTTP_200_OK)
HTTP_AUTHORIZATION=self.disallowed_credentials # self.assertNotIn('actions', response.data)
)
response = instance_view(request, pk='1') # request = factory.options(
self.assertEqual(response.status_code, status.HTTP_200_OK) # '/1',
self.assertNotIn('actions', response.data) # HTTP_AUTHORIZATION=self.updateonly_credentials
# )
def test_options_updateonly(self): # response = instance_view(request, pk='1')
request = factory.options( # self.assertEqual(response.status_code, status.HTTP_200_OK)
'/', # self.assertIn('actions', response.data)
HTTP_AUTHORIZATION=self.updateonly_credentials # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
class BasicPermModel(models.Model): class BasicPermModel(models.Model):
...@@ -167,6 +161,11 @@ class BasicPermModel(models.Model): ...@@ -167,6 +161,11 @@ class BasicPermModel(models.Model):
) )
class BasicPermSerializer(serializers.ModelSerializer):
class Meta:
model = BasicPermModel
# Custom object-level permission, that includes 'view' permissions # Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions): class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = { perms_map = {
...@@ -181,7 +180,8 @@ class ViewObjectPermissions(permissions.DjangoObjectPermissions): ...@@ -181,7 +180,8 @@ class ViewObjectPermissions(permissions.DjangoObjectPermissions):
class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicPermModel queryset = BasicPermModel.objects.all()
serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions] permission_classes = [ViewObjectPermissions]
...@@ -189,7 +189,8 @@ object_permissions_view = ObjectPermissionInstanceView.as_view() ...@@ -189,7 +189,8 @@ object_permissions_view = ObjectPermissionInstanceView.as_view()
class ObjectPermissionListView(generics.ListAPIView): class ObjectPermissionListView(generics.ListAPIView):
model = BasicPermModel queryset = BasicPermModel.objects.all()
serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions] permission_classes = [ViewObjectPermissions]
......
""" from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
General tests for relational fields. from django.core.exceptions import ImproperlyConfigured, ValidationError
"""
from __future__ import unicode_literals
from django import get_version
from django.db import models
from django.test import TestCase
from django.utils import unittest
from rest_framework import serializers from rest_framework import serializers
from tests.models import BlogPost from rest_framework.test import APISimpleTestCase
import pytest
class NullModel(models.Model):
pass class TestStringRelatedField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
class FieldTests(TestCase): self.field = serializers.StringRelatedField()
def test_pk_related_field_with_empty_string(self):
def test_string_related_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == '<MockObject name=foo, pk=1>'
class TestPrimaryKeyRelatedField(APISimpleTestCase):
def setUp(self):
self.queryset = MockQueryset([
MockObject(pk=1, name='foo'),
MockObject(pk=2, name='bar'),
MockObject(pk=3, name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset)
def test_pk_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.pk)
assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.messages[0]
assert msg == "Invalid pk '4' - object does not exist."
def test_pk_related_lookup_invalid_type(self):
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.messages[0]
assert msg == 'Incorrect type. Expected pk value, received BadType.'
def test_pk_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.pk
class TestHyperlinkedIdentityField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example')
self.field.reverse = mock_reverse
self.field.context = {'request': True}
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1/'
def test_representation_unsaved_object(self):
representation = self.field.to_representation(MockObject(pk=None))
assert representation is None
def test_representation_with_format(self):
self.field.context['format'] = 'xml'
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.xml/'
def test_improperly_configured(self):
""" """
Regression test for #446 If a matching view cannot be reversed with the given instance,
the the user has misconfigured something, as the URL conf and the
https://github.com/tomchristie/django-rest-framework/issues/446 hyperlinked field do not match.
""" """
field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) self.field.reverse = fail_reverse
self.assertRaises(serializers.ValidationError, field.from_native, '') with pytest.raises(ImproperlyConfigured):
self.assertRaises(serializers.ValidationError, field.from_native, []) self.field.to_representation(self.instance)
def test_hyperlinked_related_field_with_empty_string(self):
field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_slug_related_field_with_empty_string(self):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, [])
class TestManyRelatedMixin(TestCase):
def test_missing_many_to_many_related_field(self):
'''
Regression test for #632
https://github.com/tomchristie/django-rest-framework/pull/632
'''
field = serializers.RelatedField(many=True, read_only=False)
into = {}
field.field_from_native({}, None, 'field_name', into)
self.assertEqual(into['field_name'], [])
# Regression tests for #694 (`source` attribute on related fields)
class RelatedFieldSourceTests(TestCase):
def test_related_manager_source(self):
"""
Relational fields should be able to use manager-returning methods as their source.
"""
BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='get_blogposts_manager')
class ClassWithManagerMethod(object):
def get_blogposts_manager(self):
return BlogPost.objects
obj = ClassWithManagerMethod()
value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object'])
def test_related_queryset_source(self): class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
""" """
Relational fields should be able to use queryset-returning methods as their source. Tests for a hyperlinked identity field that has a `format` set,
""" which enforces that alternate formats are never linked too.
BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
class ClassWithQuerysetMethod(object):
def get_blogposts_queryset(self):
return BlogPost.objects.all()
obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object'])
def test_dotted_source(self): Eg. If your API includes some endpoints that accept both `.xml` and `.json`,
""" but other endpoints that only accept `.json`, we allow for hyperlinked
Source argument should support dotted.source notation. relationships that enforce only a single suffix type.
"""
BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='a.b.c')
class ClassWithQuerysetMethod(object):
a = {
'b': {
'c': BlogPost.objects.all()
}
}
obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object'])
# Regression for #1129
def test_exception_for_incorect_fk(self):
""" """
Check that the exception message are correct if the source field
doesn't exist.
"""
from tests.models import ManyToManySource
class Meta:
model = ManyToManySource
attrs = { def setUp(self):
'name': serializers.SlugRelatedField( self.instance = MockObject(pk=1, name='foo')
slug_field='name', source='banzai'), self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
'Meta': Meta, self.field.reverse = mock_reverse
} self.field.context = {'request': True}
TestSerializer = type( def test_representation(self):
str('TestSerializer'), representation = self.field.to_representation(self.instance)
(serializers.ModelSerializer,), assert representation == 'http://example.org/example/1/'
attrs
def test_representation_with_format(self):
self.field.context['format'] = 'xml'
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.json/'
class TestSlugRelatedField(APISimpleTestCase):
def setUp(self):
self.queryset = MockQueryset([
MockObject(pk=1, name='foo'),
MockObject(pk=2, name='bar'),
MockObject(pk=3, name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField(
slug_field='name', queryset=self.queryset
) )
with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'})
@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
class RelatedFieldChoicesTests(TestCase):
"""
Tests for #1408 "Web browseable API doesn't have blank option on drop down list box"
https://github.com/tomchristie/django-rest-framework/issues/1408
"""
def test_blank_option_is_added_to_choice_if_required_equals_false(self):
"""
"""
post = BlogPost(title="Checking blank option is added")
post.save()
queryset = BlogPost.objects.all()
field = serializers.RelatedField(required=False, queryset=queryset)
choice_count = BlogPost.objects.count()
widget_count = len(field.widget.choices)
self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') def test_slug_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.name)
assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist')
msg = excinfo.value.messages[0]
assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self):
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.messages[0]
assert msg == 'Invalid value.'
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.name
# Older tests, for review...
# """
# General tests for relational fields.
# """
# from __future__ import unicode_literals
# from django import get_version
# from django.db import models
# from django.test import TestCase
# from django.utils import unittest
# from rest_framework import serializers
# from tests.models import BlogPost
# class NullModel(models.Model):
# pass
# class FieldTests(TestCase):
# def test_pk_related_field_with_empty_string(self):
# """
# Regression test for #446
# https://github.com/tomchristie/django-rest-framework/issues/446
# """
# field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_hyperlinked_related_field_with_empty_string(self):
# field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_slug_related_field_with_empty_string(self):
# field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# class TestManyRelatedMixin(TestCase):
# def test_missing_many_to_many_related_field(self):
# '''
# Regression test for #632
# https://github.com/tomchristie/django-rest-framework/pull/632
# '''
# field = serializers.RelatedField(many=True, read_only=False)
# into = {}
# field.field_from_native({}, None, 'field_name', into)
# self.assertEqual(into['field_name'], [])
# # Regression tests for #694 (`source` attribute on related fields)
# class RelatedFieldSourceTests(TestCase):
# def test_related_manager_source(self):
# """
# Relational fields should be able to use manager-returning methods as their source.
# """
# BlogPost.objects.create(title='blah')
# field = serializers.RelatedField(many=True, source='get_blogposts_manager')
# class ClassWithManagerMethod(object):
# def get_blogposts_manager(self):
# return BlogPost.objects
# obj = ClassWithManagerMethod()
# value = field.field_to_native(obj, 'field_name')
# self.assertEqual(value, ['BlogPost object'])
# def test_related_queryset_source(self):
# """
# Relational fields should be able to use queryset-returning methods as their source.
# """
# BlogPost.objects.create(title='blah')
# field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
# class ClassWithQuerysetMethod(object):
# def get_blogposts_queryset(self):
# return BlogPost.objects.all()
# obj = ClassWithQuerysetMethod()
# value = field.field_to_native(obj, 'field_name')
# self.assertEqual(value, ['BlogPost object'])
# def test_dotted_source(self):
# """
# Source argument should support dotted.source notation.
# """
# BlogPost.objects.create(title='blah')
# field = serializers.RelatedField(many=True, source='a.b.c')
# class ClassWithQuerysetMethod(object):
# a = {
# 'b': {
# 'c': BlogPost.objects.all()
# }
# }
# obj = ClassWithQuerysetMethod()
# value = field.field_to_native(obj, 'field_name')
# self.assertEqual(value, ['BlogPost object'])
# # Regression for #1129
# def test_exception_for_incorect_fk(self):
# """
# Check that the exception message are correct if the source field
# doesn't exist.
# """
# from tests.models import ManyToManySource
# class Meta:
# model = ManyToManySource
# attrs = {
# 'name': serializers.SlugRelatedField(
# slug_field='name', source='banzai'),
# 'Meta': Meta,
# }
# TestSerializer = type(
# str('TestSerializer'),
# (serializers.ModelSerializer,),
# attrs
# )
# with self.assertRaises(AttributeError):
# TestSerializer(data={'name': 'foo'})
# @unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
# class RelatedFieldChoicesTests(TestCase):
# """
# Tests for #1408 "Web browseable API doesn't have blank option on drop down list box"
# https://github.com/tomchristie/django-rest-framework/issues/1408
# """
# def test_blank_option_is_added_to_choice_if_required_equals_false(self):
# """
# """
# post = BlogPost(title="Checking blank option is added")
# post.save()
# queryset = BlogPost.objects.all()
# field = serializers.RelatedField(required=False, queryset=queryset)
# choice_count = BlogPost.objects.count()
# widget_count = len(field.widget.choices)
# self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.test import APIRequestFactory # from rest_framework.test import APIRequestFactory
from tests.models import ( # from tests.models import (
BlogPost, # BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, # ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource # NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
) # )
factory = APIRequestFactory() # factory = APIRequestFactory()
request = factory.get('/') # Just to ensure we have a request in the serializer context # request = factory.get('/') # Just to ensure we have a request in the serializer context
def dummy_view(request, pk): # def dummy_view(request, pk):
pass # pass
urlpatterns = patterns( # urlpatterns = patterns(
'', # '',
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), # url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), # url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), # url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), # url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), # url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), # url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), # url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), # url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
) # )
# ManyToMany # # ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): # class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = ManyToManyTarget # model = ManyToManyTarget
fields = ('url', 'name', 'sources') # fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): # class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = ManyToManySource # model = ManyToManySource
fields = ('url', 'name', 'targets') # fields = ('url', 'name', 'targets')
# ForeignKey # # ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): # class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = ForeignKeyTarget # model = ForeignKeyTarget
fields = ('url', 'name', 'sources') # fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): # class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = ForeignKeySource # model = ForeignKeySource
fields = ('url', 'name', 'target') # fields = ('url', 'name', 'target')
# Nullable ForeignKey # # Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): # class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
fields = ('url', 'name', 'target') # fields = ('url', 'name', 'target')
# Nullable OneToOne # # Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): # class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('url', 'name', 'nullable_source') # fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid # # TODO: Add test that .data cannot be accessed prior to .is_valid
class HyperlinkedManyToManyTests(TestCase): # class HyperlinkedManyToManyTests(TestCase):
urls = 'tests.test_relations_hyperlink' # urls = 'tests.test_relations_hyperlink'
def setUp(self): # def setUp(self):
for idx in range(1, 4): # for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) # target = ManyToManyTarget(name='target-%d' % idx)
target.save() # target.save()
source = ManyToManySource(name='source-%d' % idx) # source = ManyToManySource(name='source-%d' % idx)
source.save() # source.save()
for target in ManyToManyTarget.objects.all(): # for target in ManyToManyTarget.objects.all():
source.targets.add(target) # source.targets.add(target)
def test_many_to_many_retrieve(self): # def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, # {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, # {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} # {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self): # def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} # {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self): # def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} # data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1) # instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) # serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, # {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, # {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} # {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self): # def test_reverse_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} # data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) # instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) # serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected # # Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, # {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} # {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self): # def test_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} # data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data, context={'request': request}) # serializer = ManyToManySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # # Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, # {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, # {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, # {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} # {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self): # def test_reverse_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} # data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) # serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4') # self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected # # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, # {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} # {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
class HyperlinkedForeignKeyTests(TestCase): # class HyperlinkedForeignKeyTests(TestCase):
urls = 'tests.test_relations_hyperlink' # urls = 'tests.test_relations_hyperlink'
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
new_target = ForeignKeyTarget(name='target-2') # new_target = ForeignKeyTarget(name='target-2')
new_target.save() # new_target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) # source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve(self): # def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} # {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): # def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, # {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, # {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): # def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} # data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) # serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, # {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} # {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): # def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} # data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) # serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']}) # self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} # data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2) # instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) # serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save # # We shouldn't have saved anything to the db yet since save
# hasn't been called. # # hasn't been called.
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) # new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, # {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, # {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(new_serializer.data, expected) # self.assertEqual(new_serializer.data, expected)
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected # # Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, # {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, # {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self): # def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} # data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) # serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, # {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): # def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} # data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) # serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3') # self.assertEqual(obj.name, 'target-3')
# Ensure target 4 is added, and everything else is as expected # # Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) # serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, # {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, # {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, # {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): # def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} # data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) # serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']}) # self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class HyperlinkedNullableForeignKeyTests(TestCase): # class HyperlinkedNullableForeignKeyTests(TestCase):
urls = 'tests.test_relations_hyperlink' # urls = 'tests.test_relations_hyperlink'
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
if idx == 3: # if idx == 3:
target = None # target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) # source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve_with_null(self): # def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): # def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} # data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} # {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self): # def test_foreign_key_create_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} # data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} # expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} # {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): # def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} # data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self): # def test_foreign_key_update_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} # data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} # expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) # serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, # {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, # {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only # # reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear() # # In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set. # # and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self): # # def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': 'target-1', 'sources': [1]} # # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1) # # instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data) # # serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # # self.assertTrue(serializer.is_valid())
# self.assertEqual(serializer.data, data) # # self.assertEqual(serializer.data, data)
# serializer.save() # # serializer.save()
# # Ensure target 1 is updated, and everything else is as expected # # # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all() # # queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(queryset, many=True) # # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [ # # expected = [
# {'id': 1, 'name': 'target-1', 'sources': [1]}, # # {'id': 1, 'name': 'target-1', 'sources': [1]},
# {'id': 2, 'name': 'target-2', 'sources': []}, # # {'id': 2, 'name': 'target-2', 'sources': []},
# ] # # ]
# self.assertEqual(serializer.data, expected) # # self.assertEqual(serializer.data, expected)
class HyperlinkedNullableOneToOneTests(TestCase): # class HyperlinkedNullableOneToOneTests(TestCase):
urls = 'tests.test_relations_hyperlink' # urls = 'tests.test_relations_hyperlink'
def setUp(self): # def setUp(self):
target = OneToOneTarget(name='target-1') # target = OneToOneTarget(name='target-1')
target.save() # target.save()
new_target = OneToOneTarget(name='target-2') # new_target = OneToOneTarget(name='target-2')
new_target.save() # new_target.save()
source = NullableOneToOneSource(name='source-1', target=target) # source = NullableOneToOneSource(name='source-1', target=target)
source.save() # source.save()
def test_reverse_foreign_key_retrieve_with_null(self): # def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) # serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [ # expected = [
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, # {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, # {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# Regression tests for #694 (`source` attribute on related fields) # # Regression tests for #694 (`source` attribute on related fields)
class HyperlinkedRelatedFieldSourceTests(TestCase): # class HyperlinkedRelatedFieldSourceTests(TestCase):
urls = 'tests.test_relations_hyperlink' # urls = 'tests.test_relations_hyperlink'
def test_related_manager_source(self): # def test_related_manager_source(self):
""" # """
Relational fields should be able to use manager-returning methods as their source. # Relational fields should be able to use manager-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.HyperlinkedRelatedField( # field = serializers.HyperlinkedRelatedField(
many=True, # many=True,
source='get_blogposts_manager', # source='get_blogposts_manager',
view_name='dummy-url', # view_name='dummy-url',
) # )
field.context = {'request': request} # field.context = {'request': request}
class ClassWithManagerMethod(object): # class ClassWithManagerMethod(object):
def get_blogposts_manager(self): # def get_blogposts_manager(self):
return BlogPost.objects # return BlogPost.objects
obj = ClassWithManagerMethod() # obj = ClassWithManagerMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['http://testserver/dummyurl/1/']) # self.assertEqual(value, ['http://testserver/dummyurl/1/'])
def test_related_queryset_source(self): # def test_related_queryset_source(self):
""" # """
Relational fields should be able to use queryset-returning methods as their source. # Relational fields should be able to use queryset-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.HyperlinkedRelatedField( # field = serializers.HyperlinkedRelatedField(
many=True, # many=True,
source='get_blogposts_queryset', # source='get_blogposts_queryset',
view_name='dummy-url', # view_name='dummy-url',
) # )
field.context = {'request': request} # field.context = {'request': request}
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
def get_blogposts_queryset(self): # def get_blogposts_queryset(self):
return BlogPost.objects.all() # return BlogPost.objects.all()
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['http://testserver/dummyurl/1/']) # self.assertEqual(value, ['http://testserver/dummyurl/1/'])
def test_dotted_source(self): # def test_dotted_source(self):
""" # """
Source argument should support dotted.source notation. # Source argument should support dotted.source notation.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.HyperlinkedRelatedField( # field = serializers.HyperlinkedRelatedField(
many=True, # many=True,
source='a.b.c', # source='a.b.c',
view_name='dummy-url', # view_name='dummy-url',
) # )
field.context = {'request': request} # field.context = {'request': request}
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
a = { # a = {
'b': { # 'b': {
'c': BlogPost.objects.all() # 'c': BlogPost.objects.all()
} # }
} # }
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['http://testserver/dummyurl/1/']) # self.assertEqual(value, ['http://testserver/dummyurl/1/'])
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from .models import OneToOneTarget # from .models import OneToOneTarget
class OneToOneSource(models.Model): # class OneToOneSource(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, related_name='source', # target = models.OneToOneField(OneToOneTarget, related_name='source',
null=True, blank=True) # null=True, blank=True)
class OneToManyTarget(models.Model): # class OneToManyTarget(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
class OneToManySource(models.Model): # class OneToManySource(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
target = models.ForeignKey(OneToManyTarget, related_name='sources') # target = models.ForeignKey(OneToManyTarget, related_name='sources')
class ReverseNestedOneToOneTests(TestCase): # class ReverseNestedOneToOneTests(TestCase):
def setUp(self): # def setUp(self):
class OneToOneSourceSerializer(serializers.ModelSerializer): # class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToOneSource # model = OneToOneSource
fields = ('id', 'name') # fields = ('id', 'name')
class OneToOneTargetSerializer(serializers.ModelSerializer): # class OneToOneTargetSerializer(serializers.ModelSerializer):
source = OneToOneSourceSerializer() # source = OneToOneSourceSerializer()
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('id', 'name', 'source') # fields = ('id', 'name', 'source')
self.Serializer = OneToOneTargetSerializer # self.Serializer = OneToOneTargetSerializer
for idx in range(1, 4): # for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx) # target = OneToOneTarget(name='target-%d' % idx)
target.save() # target.save()
source = OneToOneSource(name='source-%d' % idx, target=target) # source = OneToOneSource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_one_retrieve(self): # def test_one_to_one_retrieve(self):
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} # {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self): # def test_one_to_one_create(self):
data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} # data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4') # self.assertEqual(obj.name, 'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and # # Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected. # # everything else is as expected.
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, # {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
{'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} # {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self): # def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} # data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
def test_one_to_one_update(self): # def test_one_to_one_update(self):
data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} # data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
instance = OneToOneTarget.objects.get(pk=3) # instance = OneToOneTarget.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3-updated') # self.assertEqual(obj.name, 'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated, # # Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} # {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
class ForwardNestedOneToOneTests(TestCase): # class ForwardNestedOneToOneTests(TestCase):
def setUp(self): # def setUp(self):
class OneToOneTargetSerializer(serializers.ModelSerializer): # class OneToOneTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('id', 'name') # fields = ('id', 'name')
class OneToOneSourceSerializer(serializers.ModelSerializer): # class OneToOneSourceSerializer(serializers.ModelSerializer):
target = OneToOneTargetSerializer() # target = OneToOneTargetSerializer()
class Meta: # class Meta:
model = OneToOneSource # model = OneToOneSource
fields = ('id', 'name', 'target') # fields = ('id', 'name', 'target')
self.Serializer = OneToOneSourceSerializer # self.Serializer = OneToOneSourceSerializer
for idx in range(1, 4): # for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx) # target = OneToOneTarget(name='target-%d' % idx)
target.save() # target.save()
source = OneToOneSource(name='source-%d' % idx, target=target) # source = OneToOneSource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_one_retrieve(self): # def test_one_to_one_retrieve(self):
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} # {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self): # def test_one_to_one_create(self):
data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} # data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure (target 4, target_source 4, source 4) are added, and # # Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected. # # everything else is as expected.
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, # {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
{'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} # {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self): # def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} # data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
def test_one_to_one_update(self): # def test_one_to_one_update(self):
data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} # data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
instance = OneToOneSource.objects.get(pk=3) # instance = OneToOneSource.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-3-updated') # self.assertEqual(obj.name, 'source-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated, # # Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} # {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_update_to_null(self): # def test_one_to_one_update_to_null(self):
data = {'id': 3, 'name': 'source-3-updated', 'target': None} # data = {'id': 3, 'name': 'source-3-updated', 'target': None}
instance = OneToOneSource.objects.get(pk=3) # instance = OneToOneSource.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-3-updated') # self.assertEqual(obj.name, 'source-3-updated')
self.assertEqual(obj.target, None) # self.assertEqual(obj.target, None)
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3-updated', 'target': None} # {'id': 3, 'name': 'source-3-updated', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# TODO: Nullable 1-1 tests # # TODO: Nullable 1-1 tests
# def test_one_to_one_delete(self): # # def test_one_to_one_delete(self):
# data = {'id': 3, 'name': 'target-3', 'target_source': None} # # data = {'id': 3, 'name': 'target-3', 'target_source': None}
# instance = OneToOneTarget.objects.get(pk=3) # # instance = OneToOneTarget.objects.get(pk=3)
# serializer = self.Serializer(instance, data=data) # # serializer = self.Serializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # # self.assertTrue(serializer.is_valid())
# serializer.save() # # serializer.save()
# # Ensure (target_source 3, source 3) are deleted, # # # Ensure (target_source 3, source 3) are deleted,
# # and everything else is as expected. # # # and everything else is as expected.
# queryset = OneToOneTarget.objects.all() # # queryset = OneToOneTarget.objects.all()
# serializer = self.Serializer(queryset) # # serializer = self.Serializer(queryset)
# expected = [ # # expected = [
# {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
# {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
# {'id': 3, 'name': 'target-3', 'source': None} # # {'id': 3, 'name': 'target-3', 'source': None}
# ] # # ]
# self.assertEqual(serializer.data, expected) # # self.assertEqual(serializer.data, expected)
class ReverseNestedOneToManyTests(TestCase): # class ReverseNestedOneToManyTests(TestCase):
def setUp(self): # def setUp(self):
class OneToManySourceSerializer(serializers.ModelSerializer): # class OneToManySourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToManySource # model = OneToManySource
fields = ('id', 'name') # fields = ('id', 'name')
class OneToManyTargetSerializer(serializers.ModelSerializer): # class OneToManyTargetSerializer(serializers.ModelSerializer):
sources = OneToManySourceSerializer(many=True, allow_add_remove=True) # sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
class Meta: # class Meta:
model = OneToManyTarget # model = OneToManyTarget
fields = ('id', 'name', 'sources') # fields = ('id', 'name', 'sources')
self.Serializer = OneToManyTargetSerializer # self.Serializer = OneToManyTargetSerializer
target = OneToManyTarget(name='target-1') # target = OneToManyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = OneToManySource(name='source-%d' % idx, target=target) # source = OneToManySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_many_retrieve(self): # def test_one_to_many_retrieve(self):
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]}, # {'id': 3, 'name': 'source-3'}]},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_create(self): # def test_one_to_many_create(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]} # {'id': 4, 'name': 'source-4'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1') # self.assertEqual(obj.name, 'target-1')
# Ensure source 4 is added, and everything else is as # # Ensure source 4 is added, and everything else is as
# expected. # # expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]} # {'id': 4, 'name': 'source-4'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_create_with_invalid_data(self): # def test_one_to_many_create_with_invalid_data(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4}]} # {'id': 4}]}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
def test_one_to_many_update(self): # def test_one_to_many_update(self):
data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, # data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1-updated') # self.assertEqual(obj.name, 'target-1-updated')
# Ensure (target 1, source 1) are updated, # # Ensure (target 1, source 1) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, # {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_delete(self): # def test_one_to_many_delete(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
# Ensure source 2 is deleted, and everything else is as # # Ensure source 2 is deleted, and everything else is as
# expected. # # expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from django.utils import six # from django.utils import six
from rest_framework import serializers # from rest_framework import serializers
from tests.models import ( # from tests.models import (
BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, # BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, # NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
) # )
# ManyToMany # # ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer): # class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ManyToManyTarget # model = ManyToManyTarget
fields = ('id', 'name', 'sources') # fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer): # class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ManyToManySource # model = ManyToManySource
fields = ('id', 'name', 'targets') # fields = ('id', 'name', 'targets')
# ForeignKey # # ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer): # class ForeignKeyTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ForeignKeyTarget # model = ForeignKeyTarget
fields = ('id', 'name', 'sources') # fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer): # class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ForeignKeySource # model = ForeignKeySource
fields = ('id', 'name', 'target') # fields = ('id', 'name', 'target')
# Nullable ForeignKey # # Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): # class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
fields = ('id', 'name', 'target') # fields = ('id', 'name', 'target')
# Nullable OneToOne # # Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer): # class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('id', 'name', 'nullable_source') # fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid # # TODO: Add test that .data cannot be accessed prior to .is_valid
class PKManyToManyTests(TestCase): # class PKManyToManyTests(TestCase):
def setUp(self): # def setUp(self):
for idx in range(1, 4): # for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) # target = ManyToManyTarget(name='target-%d' % idx)
target.save() # target.save()
source = ManyToManySource(name='source-%d' % idx) # source = ManyToManySource(name='source-%d' % idx)
source.save() # source.save()
for target in ManyToManyTarget.objects.all(): # for target in ManyToManyTarget.objects.all():
source.targets.add(target) # source.targets.add(target)
def test_many_to_many_retrieve(self): # def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) # serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]}, # {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, # {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} # {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self): # def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) # serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, # {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]}, # {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]} # {'id': 3, 'name': 'target-3', 'sources': [3]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self): # def test_many_to_many_update(self):
data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} # data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1) # instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data) # serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) # serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, # {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, # {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} # {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self): # def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': 'target-1', 'sources': [1]} # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1) # instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) # serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected # # Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) # serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [1]}, # {'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]}, # {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]} # {'id': 3, 'name': 'target-3', 'sources': [3]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self): # def test_many_to_many_create(self):
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} # data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data) # serializer = ManyToManySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # # Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() # queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) # serializer = ManyToManySourceSerializer(queryset, many=True)
self.assertFalse(serializer.fields['targets'].read_only) # self.assertFalse(serializer.fields['targets'].read_only)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]}, # {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, # {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, # {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': 'source-4', 'targets': [1, 3]}, # {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self): # def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} # data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data) # serializer = ManyToManyTargetSerializer(data=data)
self.assertFalse(serializer.fields['sources'].read_only) # self.assertFalse(serializer.fields['sources'].read_only)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4') # self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected # # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() # queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) # serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, # {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]}, # {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}, # {'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': 'target-4', 'sources': [1, 3]} # {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
class PKForeignKeyTests(TestCase): # class PKForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
new_target = ForeignKeyTarget(name='target-2') # new_target = ForeignKeyTarget(name='target-2')
new_target.save() # new_target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) # source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve(self): # def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 1}, # {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1} # {'id': 3, 'name': 'source-3', 'target': 1}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): # def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, # {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): # def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 2} # data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 2}, # {'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1} # {'id': 3, 'name': 'source-3', 'target': 1}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): # def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 'foo'} # data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) # self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} # data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2) # instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save # # We shouldn't have saved anything to the db yet since save
# hasn't been called. # # hasn't been called.
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) # new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, # {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(new_serializer.data, expected) # self.assertEqual(new_serializer.data, expected)
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected # # Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]}, # {'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': [1, 3]}, # {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self): # def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 2} # data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data) # serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # # Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 1}, # {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}, # {'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': 'source-4', 'target': 2}, # {'id': 4, 'name': 'source-4', 'target': 2},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): # def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} # data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data) # serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3') # self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected # # Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]}, # {'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': [1, 3]}, # {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): # def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']}) # self.assertEqual(serializer.errors, {'target': ['This field is required.']})
def test_foreign_key_with_empty(self): # def test_foreign_key_with_empty(self):
""" # """
Regression test for #1072 # Regression test for #1072
https://github.com/tomchristie/django-rest-framework/issues/1072 # https://github.com/tomchristie/django-rest-framework/issues/1072
""" # """
serializer = NullableForeignKeySourceSerializer() # serializer = NullableForeignKeySourceSerializer()
self.assertEqual(serializer.data['target'], None) # self.assertEqual(serializer.data['target'], None)
class PKNullableForeignKeyTests(TestCase): # class PKNullableForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
if idx == 3: # if idx == 3:
target = None # target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) # source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve_with_null(self): # def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 1}, # {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): # def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None} # data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 1}, # {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self): # def test_foreign_key_create_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 4, 'name': 'source-4', 'target': ''} # data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None} # expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 1}, # {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): # def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self): # def test_foreign_key_update_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 1, 'name': 'source-1', 'target': ''} # data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None} # expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1}, # {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only # # reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear() # # In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set. # # and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self): # # def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': 'target-1', 'sources': [1]} # # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1) # # instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data) # # serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # # self.assertTrue(serializer.is_valid())
# self.assertEqual(serializer.data, data) # # self.assertEqual(serializer.data, data)
# serializer.save() # # serializer.save()
# # Ensure target 1 is updated, and everything else is as expected # # # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all() # # queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(queryset, many=True) # # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [ # # expected = [
# {'id': 1, 'name': 'target-1', 'sources': [1]}, # # {'id': 1, 'name': 'target-1', 'sources': [1]},
# {'id': 2, 'name': 'target-2', 'sources': []}, # # {'id': 2, 'name': 'target-2', 'sources': []},
# ] # # ]
# self.assertEqual(serializer.data, expected) # # self.assertEqual(serializer.data, expected)
class PKNullableOneToOneTests(TestCase): # class PKNullableOneToOneTests(TestCase):
def setUp(self): # def setUp(self):
target = OneToOneTarget(name='target-1') # target = OneToOneTarget(name='target-1')
target.save() # target.save()
new_target = OneToOneTarget(name='target-2') # new_target = OneToOneTarget(name='target-2')
new_target.save() # new_target.save()
source = NullableOneToOneSource(name='source-1', target=new_target) # source = NullableOneToOneSource(name='source-1', target=new_target)
source.save() # source.save()
def test_reverse_foreign_key_retrieve_with_null(self): # def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) # serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'nullable_source': None}, # {'id': 1, 'name': 'target-1', 'nullable_source': None},
{'id': 2, 'name': 'target-2', 'nullable_source': 1}, # {'id': 2, 'name': 'target-2', 'nullable_source': 1},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# The below models and tests ensure that serializer fields corresponding # # The below models and tests ensure that serializer fields corresponding
# to a ManyToManyField field with a user-specified ``through`` model are # # to a ManyToManyField field with a user-specified ``through`` model are
# set to read only # # set to read only
class ManyToManyThroughTarget(models.Model): # class ManyToManyThroughTarget(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
class ManyToManyThrough(models.Model): # class ManyToManyThrough(models.Model):
source = models.ForeignKey('ManyToManyThroughSource') # source = models.ForeignKey('ManyToManyThroughSource')
target = models.ForeignKey(ManyToManyThroughTarget) # target = models.ForeignKey(ManyToManyThroughTarget)
class ManyToManyThroughSource(models.Model): # class ManyToManyThroughSource(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
targets = models.ManyToManyField(ManyToManyThroughTarget, # targets = models.ManyToManyField(ManyToManyThroughTarget,
related_name='sources', # related_name='sources',
through='ManyToManyThrough') # through='ManyToManyThrough')
class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): # class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ManyToManyThroughTarget # model = ManyToManyThroughTarget
fields = ('id', 'name', 'sources') # fields = ('id', 'name', 'sources')
class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): # class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ManyToManyThroughSource # model = ManyToManyThroughSource
fields = ('id', 'name', 'targets') # fields = ('id', 'name', 'targets')
class PKManyToManyThroughTests(TestCase): # class PKManyToManyThroughTests(TestCase):
def setUp(self): # def setUp(self):
self.source = ManyToManyThroughSource.objects.create( # self.source = ManyToManyThroughSource.objects.create(
name='through-source-1') # name='through-source-1')
self.target = ManyToManyThroughTarget.objects.create( # self.target = ManyToManyThroughTarget.objects.create(
name='through-target-1') # name='through-target-1')
def test_many_to_many_create(self): # def test_many_to_many_create(self):
data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} # data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
serializer = ManyToManyThroughSourceSerializer(data=data) # serializer = ManyToManyThroughSourceSerializer(data=data)
self.assertTrue(serializer.fields['targets'].read_only) # self.assertTrue(serializer.fields['targets'].read_only)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(obj.name, 'source-2') # self.assertEqual(obj.name, 'source-2')
self.assertEqual(obj.targets.count(), 0) # self.assertEqual(obj.targets.count(), 0)
def test_many_to_many_reverse_create(self): # def test_many_to_many_reverse_create(self):
data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} # data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
serializer = ManyToManyThroughTargetSerializer(data=data) # serializer = ManyToManyThroughTargetSerializer(data=data)
self.assertTrue(serializer.fields['sources'].read_only) # self.assertTrue(serializer.fields['sources'].read_only)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
obj = serializer.save() # obj = serializer.save()
self.assertEqual(obj.name, 'target-2') # self.assertEqual(obj.name, 'target-2')
self.assertEqual(obj.sources.count(), 0) # self.assertEqual(obj.sources.count(), 0)
# Regression tests for #694 (`source` attribute on related fields) # # Regression tests for #694 (`source` attribute on related fields)
class PrimaryKeyRelatedFieldSourceTests(TestCase): # class PrimaryKeyRelatedFieldSourceTests(TestCase):
def test_related_manager_source(self): # def test_related_manager_source(self):
""" # """
Relational fields should be able to use manager-returning methods as their source. # Relational fields should be able to use manager-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') # field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
class ClassWithManagerMethod(object): # class ClassWithManagerMethod(object):
def get_blogposts_manager(self): # def get_blogposts_manager(self):
return BlogPost.objects # return BlogPost.objects
obj = ClassWithManagerMethod() # obj = ClassWithManagerMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, [1]) # self.assertEqual(value, [1])
def test_related_queryset_source(self): # def test_related_queryset_source(self):
""" # """
Relational fields should be able to use queryset-returning methods as their source. # Relational fields should be able to use queryset-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') # field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
def get_blogposts_queryset(self): # def get_blogposts_queryset(self):
return BlogPost.objects.all() # return BlogPost.objects.all()
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, [1]) # self.assertEqual(value, [1])
def test_dotted_source(self): # def test_dotted_source(self):
""" # """
Source argument should support dotted.source notation. # Source argument should support dotted.source notation.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') # field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
a = { # a = {
'b': { # 'b': {
'c': BlogPost.objects.all() # 'c': BlogPost.objects.all()
} # }
} # }
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, [1]) # self.assertEqual(value, [1])
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget # from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
class ForeignKeyTargetSerializer(serializers.ModelSerializer): # class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.SlugRelatedField(many=True, slug_field='name') # sources = serializers.SlugRelatedField(many=True, slug_field='name')
class Meta: # class Meta:
model = ForeignKeyTarget # model = ForeignKeyTarget
class ForeignKeySourceSerializer(serializers.ModelSerializer): # class ForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name') # target = serializers.SlugRelatedField(slug_field='name')
class Meta: # class Meta:
model = ForeignKeySource # model = ForeignKeySource
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): # class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name', required=False) # target = serializers.SlugRelatedField(slug_field='name', required=False)
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
# TODO: M2M Tests, FKTests (Non-nullable), One2One # # TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase): # class SlugForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
new_target = ForeignKeyTarget(name='target-2') # new_target = ForeignKeyTarget(name='target-2')
new_target.save() # new_target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) # source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve(self): # def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'} # {'id': 3, 'name': 'source-3', 'target': 'target-1'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): # def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): # def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} # data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-2'}, # {'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'} # {'id': 3, 'name': 'source-3', 'target': 'target-1'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): # def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 123} # data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) # self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} # data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) # instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save # # We shouldn't have saved anything to the db yet since save
# hasn't been called. # # hasn't been called.
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) # new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(new_serializer.data, expected) # self.assertEqual(new_serializer.data, expected)
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected # # Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, # {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self): # def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} # data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data) # serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid() # serializer.is_valid()
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # # Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}, # {'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'}, # {'id': 4, 'name': 'source-4', 'target': 'target-2'},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): # def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} # data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) # serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3') # self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected # # Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, # {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): # def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']}) # self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class SlugNullableForeignKeyTests(TestCase): # class SlugNullableForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
if idx == 3: # if idx == 3:
target = None # target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) # source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve_with_null(self): # def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): # def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None} # data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self): # def test_foreign_key_create_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 4, 'name': 'source-4', 'target': ''} # data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None} # expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): # def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self): # def test_foreign_key_update_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 1, 'name': 'source-1', 'target': ''} # data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None} # expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
...@@ -13,7 +13,7 @@ from rest_framework.compat import yaml, etree, StringIO ...@@ -13,7 +13,7 @@ from rest_framework.compat import yaml, etree, StringIO
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
...@@ -32,7 +32,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') ...@@ -32,7 +32,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [ expected_results = [
((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1,2,3]') # Generator
] ]
...@@ -270,7 +270,7 @@ class RendererEndToEndTests(TestCase): ...@@ -270,7 +270,7 @@ class RendererEndToEndTests(TestCase):
self.assertNotContains(resp, '>text/html; charset=utf-8<') self.assertNotContains(resp, '>text/html; charset=utf-8<')
_flat_repr = '{"foo": ["bar", "baz"]}' _flat_repr = '{"foo":["bar","baz"]}'
_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' _indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
...@@ -373,22 +373,29 @@ class JSONRendererTests(TestCase): ...@@ -373,22 +373,29 @@ class JSONRendererTests(TestCase):
content = renderer.render(obj, 'application/json; indent=2') content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
def test_check_ascii(self):
class UnicodeJSONRendererTests(TestCase):
"""
Tests specific for the Unicode JSON Renderer
"""
def test_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8'))
class UnicodeJSONRendererTests(TestCase): class AsciiJSONRendererTests(TestCase):
""" """
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): def test_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = UnicodeJSONRenderer() renderer = AsciiJSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8'))
class JSONPRendererTests(TestCase): class JSONPRendererTests(TestCase):
...@@ -487,13 +494,9 @@ if yaml: ...@@ -487,13 +494,9 @@ if yaml:
def assertYAMLContains(self, content, string): def assertYAMLContains(self, content, string):
self.assertTrue(string in content, '%r not in %r' % (string, content)) self.assertTrue(string in content, '%r not in %r' % (string, content))
class UnicodeYAMLRendererTests(TestCase):
"""
Tests specific for the Unicode YAML Renderer
"""
def test_proper_encoding(self): def test_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = UnicodeYAMLRenderer() renderer = YAMLRenderer()
content = renderer.render(obj, 'application/yaml') content = renderer.render(obj, 'application/yaml')
self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8'))
......
...@@ -2,11 +2,12 @@ from __future__ import unicode_literals ...@@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.conf.urls import patterns, url, include from django.conf.urls import patterns, url, include
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from tests.models import BasicModel, BasicModelSerializer from tests.models import BasicModel
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework import generics from rest_framework import generics
from rest_framework import routers from rest_framework import routers
from rest_framework import serializers
from rest_framework import status from rest_framework import status
from rest_framework.renderers import ( from rest_framework.renderers import (
BaseRenderer, BaseRenderer,
...@@ -17,6 +18,12 @@ from rest_framework import viewsets ...@@ -17,6 +18,12 @@ from rest_framework import viewsets
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class MockPickleRenderer(BaseRenderer): class MockPickleRenderer(BaseRenderer):
media_type = 'application/pickle' media_type = 'application/pickle'
...@@ -86,14 +93,15 @@ class HTMLView1(APIView): ...@@ -86,14 +93,15 @@ class HTMLView1(APIView):
class HTMLNewModelViewSet(viewsets.ModelViewSet): class HTMLNewModelViewSet(viewsets.ModelViewSet):
model = BasicModel serializer_class = BasicModelSerializer
queryset = BasicModel.objects.all()
class HTMLNewModelView(generics.ListCreateAPIView): class HTMLNewModelView(generics.ListCreateAPIView):
renderer_classes = (BrowsableAPIRenderer,) renderer_classes = (BrowsableAPIRenderer,)
permission_classes = [] permission_classes = []
serializer_class = BasicModelSerializer serializer_class = BasicModelSerializer
model = BasicModel queryset = BasicModel.objects.all()
new_model_viewset_router = routers.DefaultRouter() new_model_viewset_router = routers.DefaultRouter()
...@@ -224,8 +232,8 @@ class Issue467Tests(TestCase): ...@@ -224,8 +232,8 @@ class Issue467Tests(TestCase):
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
class Issue807Tests(TestCase): class Issue807Tests(TestCase):
...@@ -269,11 +277,11 @@ class Issue807Tests(TestCase): ...@@ -269,11 +277,11 @@ class Issue807Tests(TestCase):
) )
resp = self.client.get('/html_new_model_viewset/' + param) resp = self.client.get('/html_new_model_viewset/' + param)
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
...@@ -3,7 +3,7 @@ from django.conf.urls import patterns, url, include ...@@ -3,7 +3,7 @@ from django.conf.urls import patterns, url, include
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions from rest_framework import serializers, viewsets, mixins, permissions
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.routers import SimpleRouter, DefaultRouter
...@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase): ...@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
def setUp(self): def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer): class NoteSerializer(serializers.HyperlinkedModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
class Meta: class Meta:
model = RouterTestModel model = RouterTestModel
lookup_field = 'uuid'
fields = ('url', 'uuid', 'text') fields = ('url', 'uuid', 'text')
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
...@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase): ...@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
serializer_class = NoteSerializer serializer_class = NoteSerializer
lookup_field = 'uuid' lookup_field = 'uuid'
RouterTestModel.objects.create(uuid='123', text='foo bar')
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes', NoteViewSet)
...@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase): ...@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
) )
RouterTestModel.objects.create(uuid='123', text='foo bar')
def test_custom_lookup_field_route(self): def test_custom_lookup_field_route(self):
detail_route = self.router.urls[-1] detail_route = self.router.urls[-1]
detail_url_pattern = detail_route.regex.pattern detail_url_pattern = detail_route.regex.pattern
...@@ -284,3 +285,19 @@ class TestDynamicListAndDetailRouter(TestCase): ...@@ -284,3 +285,19 @@ class TestDynamicListAndDetailRouter(TestCase):
else: else:
method_map = 'get' method_map = 'get'
self.assertEqual(route.mapping[method_map], endpoint) self.assertEqual(route.mapping[method_map], endpoint)
class TestRootWithAListlessViewset(TestCase):
def setUp(self):
class NoteViewSet(mixins.RetrieveModelMixin,
viewsets.GenericViewSet):
model = RouterTestModel
self.router = DefaultRouter()
self.router.register(r'notes', NoteViewSet)
self.view = self.router.urls[0].callback
def test_api_root(self):
request = factory.get('/')
response = self.view(request)
self.assertEqual(response.data, {})
This source diff could not be displayed because it is too large. You can view the blob instead.
""" # """
Tests to cover bulk create and update using serializers. # Tests to cover bulk create and update using serializers.
""" # """
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
class BulkCreateSerializerTests(TestCase): # class BulkCreateSerializerTests(TestCase):
""" # """
Creating multiple instances using serializers. # Creating multiple instances using serializers.
""" # """
def setUp(self): # def setUp(self):
class BookSerializer(serializers.Serializer): # class BookSerializer(serializers.Serializer):
id = serializers.IntegerField() # id = serializers.IntegerField()
title = serializers.CharField(max_length=100) # title = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100) # author = serializers.CharField(max_length=100)
self.BookSerializer = BookSerializer # self.BookSerializer = BookSerializer
def test_bulk_create_success(self): # def test_bulk_create_success(self):
""" # """
Correct bulk update serialization should return the input data. # Correct bulk update serialization should return the input data.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 1, # 'id': 1,
'title': 'If this is a man', # 'title': 'If this is a man',
'author': 'Primo Levi' # 'author': 'Primo Levi'
}, { # }, {
'id': 2, # 'id': 2,
'title': 'The wind-up bird chronicle', # 'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
serializer = self.BookSerializer(data=data, many=True) # serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, data) # self.assertEqual(serializer.object, data)
def test_bulk_create_errors(self): # def test_bulk_create_errors(self):
""" # """
Correct bulk update serialization should return the input data. # Correct bulk update serialization should return the input data.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 1, # 'id': 1,
'title': 'If this is a man', # 'title': 'If this is a man',
'author': 'Primo Levi' # 'author': 'Primo Levi'
}, { # }, {
'id': 'foo', # 'id': 'foo',
'title': 'The wind-up bird chronicle', # 'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
expected_errors = [ # expected_errors = [
{}, # {},
{}, # {},
{'id': ['Enter a whole number.']} # {'id': ['Enter a whole number.']}
] # ]
serializer = self.BookSerializer(data=data, many=True) # serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
def test_invalid_list_datatype(self): # def test_invalid_list_datatype(self):
""" # """
Data containing list of incorrect data type should return errors. # Data containing list of incorrect data type should return errors.
""" # """
data = ['foo', 'bar', 'baz'] # data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True) # serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
expected_errors = [ # expected_errors = [
{'non_field_errors': ['Invalid data']}, # {'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']}, # {'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']} # {'non_field_errors': ['Invalid data']}
] # ]
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
def test_invalid_single_datatype(self): # def test_invalid_single_datatype(self):
""" # """
Data containing a single incorrect data type should return errors. # Data containing a single incorrect data type should return errors.
""" # """
data = 123 # data = 123
serializer = self.BookSerializer(data=data, many=True) # serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
expected_errors = {'non_field_errors': ['Expected a list of items.']} # expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
def test_invalid_single_object(self): # def test_invalid_single_object(self):
""" # """
Data containing only a single object, instead of a list of objects # Data containing only a single object, instead of a list of objects
should return errors. # should return errors.
""" # """
data = { # data = {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
} # }
serializer = self.BookSerializer(data=data, many=True) # serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
expected_errors = {'non_field_errors': ['Expected a list of items.']} # expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
class BulkUpdateSerializerTests(TestCase): # class BulkUpdateSerializerTests(TestCase):
""" # """
Updating multiple instances using serializers. # Updating multiple instances using serializers.
""" # """
def setUp(self): # def setUp(self):
class Book(object): # class Book(object):
""" # """
A data type that can be persisted to a mock storage backend # A data type that can be persisted to a mock storage backend
with `.save()` and `.delete()`. # with `.save()` and `.delete()`.
""" # """
object_map = {} # object_map = {}
def __init__(self, id, title, author): # def __init__(self, id, title, author):
self.id = id # self.id = id
self.title = title # self.title = title
self.author = author # self.author = author
def save(self): # def save(self):
Book.object_map[self.id] = self # Book.object_map[self.id] = self
def delete(self): # def delete(self):
del Book.object_map[self.id] # del Book.object_map[self.id]
class BookSerializer(serializers.Serializer): # class BookSerializer(serializers.Serializer):
id = serializers.IntegerField() # id = serializers.IntegerField()
title = serializers.CharField(max_length=100) # title = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100) # author = serializers.CharField(max_length=100)
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.id = attrs['id'] # instance.id = attrs['id']
instance.title = attrs['title'] # instance.title = attrs['title']
instance.author = attrs['author'] # instance.author = attrs['author']
return instance # return instance
return Book(**attrs) # return Book(**attrs)
self.Book = Book # self.Book = Book
self.BookSerializer = BookSerializer # self.BookSerializer = BookSerializer
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 1, # 'id': 1,
'title': 'If this is a man', # 'title': 'If this is a man',
'author': 'Primo Levi' # 'author': 'Primo Levi'
}, { # }, {
'id': 2, # 'id': 2,
'title': 'The wind-up bird chronicle', # 'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
for item in data: # for item in data:
book = Book(item['id'], item['title'], item['author']) # book = Book(item['id'], item['title'], item['author'])
book.save() # book.save()
def books(self): # def books(self):
""" # """
Return all the objects in the mock storage backend. # Return all the objects in the mock storage backend.
""" # """
return self.Book.object_map.values() # return self.Book.object_map.values()
def test_bulk_update_success(self): # def test_bulk_update_success(self):
""" # """
Correct bulk update serialization should return the input data. # Correct bulk update serialization should return the input data.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 2, # 'id': 2,
'title': 'Kafka on the shore', # 'title': 'Kafka on the shore',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) # serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data # new_data = self.BookSerializer(self.books(), many=True).data
self.assertEqual(data, new_data) # self.assertEqual(data, new_data)
def test_bulk_update_and_create(self): # def test_bulk_update_and_create(self):
""" # """
Bulk update serialization may also include created items. # Bulk update serialization may also include created items.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 3, # 'id': 3,
'title': 'Kafka on the shore', # 'title': 'Kafka on the shore',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) # serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data # new_data = self.BookSerializer(self.books(), many=True).data
self.assertEqual(data, new_data) # self.assertEqual(data, new_data)
def test_bulk_update_invalid_create(self): # def test_bulk_update_invalid_create(self):
""" # """
Bulk update serialization without allow_add_remove may not create items. # Bulk update serialization without allow_add_remove may not create items.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 3, # 'id': 3,
'title': 'Kafka on the shore', # 'title': 'Kafka on the shore',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
expected_errors = [ # expected_errors = [
{}, # {},
{'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} # {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
] # ]
serializer = self.BookSerializer(self.books(), data=data, many=True) # serializer = self.BookSerializer(self.books(), data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
def test_bulk_update_error(self): # def test_bulk_update_error(self):
""" # """
Incorrect bulk update serialization should return error data. # Incorrect bulk update serialization should return error data.
""" # """
data = [ # data = [
{ # {
'id': 0, # 'id': 0,
'title': 'The electric kool-aid acid test', # 'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe' # 'author': 'Tom Wolfe'
}, { # }, {
'id': 'foo', # 'id': 'foo',
'title': 'Kafka on the shore', # 'title': 'Kafka on the shore',
'author': 'Haruki Murakami' # 'author': 'Haruki Murakami'
} # }
] # ]
expected_errors = [ # expected_errors = [
{}, # {},
{'id': ['Enter a whole number.']} # {'id': ['Enter a whole number.']}
] # ]
serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) # serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
class EmptySerializerTestCase(TestCase): # class EmptySerializerTestCase(TestCase):
def test_empty_serializer(self): # def test_empty_serializer(self):
class FooBarSerializer(serializers.Serializer): # class FooBarSerializer(serializers.Serializer):
foo = serializers.IntegerField() # foo = serializers.IntegerField()
bar = serializers.SerializerMethodField('get_bar') # bar = serializers.SerializerMethodField()
def get_bar(self, obj): # def get_bar(self, obj):
return 'bar' # return 'bar'
serializer = FooBarSerializer() # serializer = FooBarSerializer()
self.assertEquals(serializer.data, {'foo': 0}) # self.assertEquals(serializer.data, {'foo': 0})
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from tests.accounts.serializers import AccountSerializer # from tests.accounts.serializers import AccountSerializer
class ImportingModelSerializerTests(TestCase): # class ImportingModelSerializerTests(TestCase):
""" # """
In some situations like, GH #1225, it is possible, especially in # In some situations like, GH #1225, it is possible, especially in
testing, to import a serializer who's related models have not yet # testing, to import a serializer who's related models have not yet
been resolved by Django. `AccountSerializer` is an example of such # been resolved by Django. `AccountSerializer` is an example of such
a serializer (imported at the top of this file). # a serializer (imported at the top of this file).
""" # """
def test_import_model_serializer(self): # def test_import_model_serializer(self):
""" # """
The serializer at the top of this file should have been # The serializer at the top of this file should have been
imported successfully, and we should be able to instantiate it. # imported successfully, and we should be able to instantiate it.
""" # """
self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) # self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
""" # """
Tests to cover nested serializers. # Tests to cover nested serializers.
Doesn't cover model serializers. # Doesn't cover model serializers.
""" # """
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from . import models # from . import models
class WritableNestedSerializerBasicTests(TestCase): # class WritableNestedSerializerBasicTests(TestCase):
""" # """
Tests for deserializing nested entities. # Tests for deserializing nested entities.
Basic tests that use serializers that simply restore to dicts. # Basic tests that use serializers that simply restore to dicts.
""" # """
def setUp(self): # def setUp(self):
class TrackSerializer(serializers.Serializer): # class TrackSerializer(serializers.Serializer):
order = serializers.IntegerField() # order = serializers.IntegerField()
title = serializers.CharField(max_length=100) # title = serializers.CharField(max_length=100)
duration = serializers.IntegerField() # duration = serializers.IntegerField()
class AlbumSerializer(serializers.Serializer): # class AlbumSerializer(serializers.Serializer):
album_name = serializers.CharField(max_length=100) # album_name = serializers.CharField(max_length=100)
artist = serializers.CharField(max_length=100) # artist = serializers.CharField(max_length=100)
tracks = TrackSerializer(many=True) # tracks = TrackSerializer(many=True)
self.AlbumSerializer = AlbumSerializer # self.AlbumSerializer = AlbumSerializer
def test_nested_validation_success(self): # def test_nested_validation_success(self):
""" # """
Correct nested serialization should return the input data. # Correct nested serialization should return the input data.
""" # """
data = { # data = {
'album_name': 'Discovery', # 'album_name': 'Discovery',
'artist': 'Daft Punk', # 'artist': 'Daft Punk',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'One More Time', 'duration': 235}, # {'order': 1, 'title': 'One More Time', 'duration': 235},
{'order': 2, 'title': 'Aerodynamic', 'duration': 184}, # {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
{'order': 3, 'title': 'Digital Love', 'duration': 239} # {'order': 3, 'title': 'Digital Love', 'duration': 239}
] # ]
} # }
serializer = self.AlbumSerializer(data=data) # serializer = self.AlbumSerializer(data=data)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, data) # self.assertEqual(serializer.object, data)
def test_nested_validation_error(self): # def test_nested_validation_error(self):
""" # """
Incorrect nested serialization should return appropriate error data. # Incorrect nested serialization should return appropriate error data.
""" # """
data = { # data = {
'album_name': 'Discovery', # 'album_name': 'Discovery',
'artist': 'Daft Punk', # 'artist': 'Daft Punk',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'One More Time', 'duration': 235}, # {'order': 1, 'title': 'One More Time', 'duration': 235},
{'order': 2, 'title': 'Aerodynamic', 'duration': 184}, # {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
{'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} # {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
] # ]
} # }
expected_errors = { # expected_errors = {
'tracks': [ # 'tracks': [
{}, # {},
{}, # {},
{'duration': ['Enter a whole number.']} # {'duration': ['Enter a whole number.']}
] # ]
} # }
serializer = self.AlbumSerializer(data=data) # serializer = self.AlbumSerializer(data=data)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
def test_many_nested_validation_error(self): # def test_many_nested_validation_error(self):
""" # """
Incorrect nested serialization should return appropriate error data # Incorrect nested serialization should return appropriate error data
when multiple entities are being deserialized. # when multiple entities are being deserialized.
""" # """
data = [ # data = [
{ # {
'album_name': 'Russian Red', # 'album_name': 'Russian Red',
'artist': 'I Love Your Glasses', # 'artist': 'I Love Your Glasses',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'Cigarettes', 'duration': 121}, # {'order': 1, 'title': 'Cigarettes', 'duration': 121},
{'order': 2, 'title': 'No Past Land', 'duration': 198}, # {'order': 2, 'title': 'No Past Land', 'duration': 198},
{'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} # {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
] # ]
}, # },
{ # {
'album_name': 'Discovery', # 'album_name': 'Discovery',
'artist': 'Daft Punk', # 'artist': 'Daft Punk',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'One More Time', 'duration': 235}, # {'order': 1, 'title': 'One More Time', 'duration': 235},
{'order': 2, 'title': 'Aerodynamic', 'duration': 184}, # {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
{'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} # {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
] # ]
} # }
] # ]
expected_errors = [ # expected_errors = [
{}, # {},
{ # {
'tracks': [ # 'tracks': [
{}, # {},
{}, # {},
{'duration': ['Enter a whole number.']} # {'duration': ['Enter a whole number.']}
] # ]
} # }
] # ]
serializer = self.AlbumSerializer(data=data, many=True) # serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False) # self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors) # self.assertEqual(serializer.errors, expected_errors)
class WritableNestedSerializerObjectTests(TestCase): # class WritableNestedSerializerObjectTests(TestCase):
""" # """
Tests for deserializing nested entities. # Tests for deserializing nested entities.
These tests use serializers that restore to concrete objects. # These tests use serializers that restore to concrete objects.
""" # """
def setUp(self): # def setUp(self):
# Couple of concrete objects that we're going to deserialize into # # Couple of concrete objects that we're going to deserialize into
class Track(object): # class Track(object):
def __init__(self, order, title, duration): # def __init__(self, order, title, duration):
self.order, self.title, self.duration = order, title, duration # self.order, self.title, self.duration = order, title, duration
def __eq__(self, other): # def __eq__(self, other):
return ( # return (
self.order == other.order and # self.order == other.order and
self.title == other.title and # self.title == other.title and
self.duration == other.duration # self.duration == other.duration
) # )
class Album(object): # class Album(object):
def __init__(self, album_name, artist, tracks): # def __init__(self, album_name, artist, tracks):
self.album_name, self.artist, self.tracks = album_name, artist, tracks # self.album_name, self.artist, self.tracks = album_name, artist, tracks
def __eq__(self, other): # def __eq__(self, other):
return ( # return (
self.album_name == other.album_name and # self.album_name == other.album_name and
self.artist == other.artist and # self.artist == other.artist and
self.tracks == other.tracks # self.tracks == other.tracks
) # )
# And their corresponding serializers # # And their corresponding serializers
class TrackSerializer(serializers.Serializer): # class TrackSerializer(serializers.Serializer):
order = serializers.IntegerField() # order = serializers.IntegerField()
title = serializers.CharField(max_length=100) # title = serializers.CharField(max_length=100)
duration = serializers.IntegerField() # duration = serializers.IntegerField()
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
return Track(attrs['order'], attrs['title'], attrs['duration']) # return Track(attrs['order'], attrs['title'], attrs['duration'])
class AlbumSerializer(serializers.Serializer): # class AlbumSerializer(serializers.Serializer):
album_name = serializers.CharField(max_length=100) # album_name = serializers.CharField(max_length=100)
artist = serializers.CharField(max_length=100) # artist = serializers.CharField(max_length=100)
tracks = TrackSerializer(many=True) # tracks = TrackSerializer(many=True)
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) # return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
self.Album, self.Track = Album, Track # self.Album, self.Track = Album, Track
self.AlbumSerializer = AlbumSerializer # self.AlbumSerializer = AlbumSerializer
def test_nested_validation_success(self): # def test_nested_validation_success(self):
""" # """
Correct nested serialization should return a restored object # Correct nested serialization should return a restored object
that corresponds to the input data. # that corresponds to the input data.
""" # """
data = { # data = {
'album_name': 'Discovery', # 'album_name': 'Discovery',
'artist': 'Daft Punk', # 'artist': 'Daft Punk',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'One More Time', 'duration': 235}, # {'order': 1, 'title': 'One More Time', 'duration': 235},
{'order': 2, 'title': 'Aerodynamic', 'duration': 184}, # {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
{'order': 3, 'title': 'Digital Love', 'duration': 239} # {'order': 3, 'title': 'Digital Love', 'duration': 239}
] # ]
} # }
expected_object = self.Album( # expected_object = self.Album(
album_name='Discovery', # album_name='Discovery',
artist='Daft Punk', # artist='Daft Punk',
tracks=[ # tracks=[
self.Track(order=1, title='One More Time', duration=235), # self.Track(order=1, title='One More Time', duration=235),
self.Track(order=2, title='Aerodynamic', duration=184), # self.Track(order=2, title='Aerodynamic', duration=184),
self.Track(order=3, title='Digital Love', duration=239), # self.Track(order=3, title='Digital Love', duration=239),
] # ]
) # )
serializer = self.AlbumSerializer(data=data) # serializer = self.AlbumSerializer(data=data)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object) # self.assertEqual(serializer.object, expected_object)
def test_many_nested_validation_success(self): # def test_many_nested_validation_success(self):
""" # """
Correct nested serialization should return multiple restored objects # Correct nested serialization should return multiple restored objects
that corresponds to the input data when multiple objects are # that corresponds to the input data when multiple objects are
being deserialized. # being deserialized.
""" # """
data = [ # data = [
{ # {
'album_name': 'Russian Red', # 'album_name': 'Russian Red',
'artist': 'I Love Your Glasses', # 'artist': 'I Love Your Glasses',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'Cigarettes', 'duration': 121}, # {'order': 1, 'title': 'Cigarettes', 'duration': 121},
{'order': 2, 'title': 'No Past Land', 'duration': 198}, # {'order': 2, 'title': 'No Past Land', 'duration': 198},
{'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} # {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
] # ]
}, # },
{ # {
'album_name': 'Discovery', # 'album_name': 'Discovery',
'artist': 'Daft Punk', # 'artist': 'Daft Punk',
'tracks': [ # 'tracks': [
{'order': 1, 'title': 'One More Time', 'duration': 235}, # {'order': 1, 'title': 'One More Time', 'duration': 235},
{'order': 2, 'title': 'Aerodynamic', 'duration': 184}, # {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
{'order': 3, 'title': 'Digital Love', 'duration': 239} # {'order': 3, 'title': 'Digital Love', 'duration': 239}
] # ]
} # }
] # ]
expected_object = [ # expected_object = [
self.Album( # self.Album(
album_name='Russian Red', # album_name='Russian Red',
artist='I Love Your Glasses', # artist='I Love Your Glasses',
tracks=[ # tracks=[
self.Track(order=1, title='Cigarettes', duration=121), # self.Track(order=1, title='Cigarettes', duration=121),
self.Track(order=2, title='No Past Land', duration=198), # self.Track(order=2, title='No Past Land', duration=198),
self.Track(order=3, title='They Don\'t Believe', duration=191), # self.Track(order=3, title='They Don\'t Believe', duration=191),
] # ]
), # ),
self.Album( # self.Album(
album_name='Discovery', # album_name='Discovery',
artist='Daft Punk', # artist='Daft Punk',
tracks=[ # tracks=[
self.Track(order=1, title='One More Time', duration=235), # self.Track(order=1, title='One More Time', duration=235),
self.Track(order=2, title='Aerodynamic', duration=184), # self.Track(order=2, title='Aerodynamic', duration=184),
self.Track(order=3, title='Digital Love', duration=239), # self.Track(order=3, title='Digital Love', duration=239),
] # ]
) # )
] # ]
serializer = self.AlbumSerializer(data=data, many=True) # serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object) # self.assertEqual(serializer.object, expected_object)
class ForeignKeyNestedSerializerUpdateTests(TestCase): # class ForeignKeyNestedSerializerUpdateTests(TestCase):
def setUp(self): # def setUp(self):
class Artist(object): # class Artist(object):
def __init__(self, name): # def __init__(self, name):
self.name = name # self.name = name
def __eq__(self, other): # def __eq__(self, other):
return self.name == other.name # return self.name == other.name
class Album(object): # class Album(object):
def __init__(self, name, artist): # def __init__(self, name, artist):
self.name, self.artist = name, artist # self.name, self.artist = name, artist
def __eq__(self, other): # def __eq__(self, other):
return self.name == other.name and self.artist == other.artist # return self.name == other.name and self.artist == other.artist
class ArtistSerializer(serializers.Serializer): # class ArtistSerializer(serializers.Serializer):
name = serializers.CharField() # name = serializers.CharField()
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.name = attrs['name'] # instance.name = attrs['name']
else: # else:
instance = Artist(attrs['name']) # instance = Artist(attrs['name'])
return instance # return instance
class AlbumSerializer(serializers.Serializer): # class AlbumSerializer(serializers.Serializer):
name = serializers.CharField() # name = serializers.CharField()
by = ArtistSerializer(source='artist') # by = ArtistSerializer(source='artist')
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.name = attrs['name'] # instance.name = attrs['name']
instance.artist = attrs['artist'] # instance.artist = attrs['artist']
else: # else:
instance = Album(attrs['name'], attrs['artist']) # instance = Album(attrs['name'], attrs['artist'])
return instance # return instance
self.Artist = Artist # self.Artist = Artist
self.Album = Album # self.Album = Album
self.AlbumSerializer = AlbumSerializer # self.AlbumSerializer = AlbumSerializer
def test_create_via_foreign_key_with_source(self): # def test_create_via_foreign_key_with_source(self):
""" # """
Check that we can both *create* and *update* into objects across # Check that we can both *create* and *update* into objects across
ForeignKeys that have a `source` specified. # ForeignKeys that have a `source` specified.
Regression test for #1170 # Regression test for #1170
""" # """
data = { # data = {
'name': 'Discovery', # 'name': 'Discovery',
'by': {'name': 'Daft Punk'}, # 'by': {'name': 'Daft Punk'},
} # }
expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery') # expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
# create # # create
serializer = self.AlbumSerializer(data=data) # serializer = self.AlbumSerializer(data=data)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected) # self.assertEqual(serializer.object, expected)
# update # # update
original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters') # original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
serializer = self.AlbumSerializer(instance=original, data=data) # serializer = self.AlbumSerializer(instance=original, data=data)
self.assertEqual(serializer.is_valid(), True) # self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected) # self.assertEqual(serializer.object, expected)
class NestedModelSerializerUpdateTests(TestCase): # class NestedModelSerializerUpdateTests(TestCase):
def test_second_nested_level(self): # def test_second_nested_level(self):
john = models.Person.objects.create(name="john") # john = models.Person.objects.create(name="john")
post = john.blogpost_set.create(title="Test blog post") # post = john.blogpost_set.create(title="Test blog post")
post.blogpostcomment_set.create(text="I hate this blog post") # post.blogpostcomment_set.create(text="I hate this blog post")
post.blogpostcomment_set.create(text="I love this blog post") # post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.ModelSerializer): # class BlogPostCommentSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = models.BlogPostComment # model = models.BlogPostComment
class BlogPostSerializer(serializers.ModelSerializer): # class BlogPostSerializer(serializers.ModelSerializer):
comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') # comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
class Meta: # class Meta:
model = models.BlogPost # model = models.BlogPost
fields = ('id', 'title', 'comments') # fields = ('id', 'title', 'comments')
class PersonSerializer(serializers.ModelSerializer): # class PersonSerializer(serializers.ModelSerializer):
posts = BlogPostSerializer(many=True, source='blogpost_set') # posts = BlogPostSerializer(many=True, source='blogpost_set')
class Meta: # class Meta:
model = models.Person # model = models.Person
fields = ('id', 'name', 'age', 'posts') # fields = ('id', 'name', 'age', 'posts')
serialize = PersonSerializer(instance=john) # serialize = PersonSerializer(instance=john)
deserialize = PersonSerializer(data=serialize.data, instance=john) # deserialize = PersonSerializer(data=serialize.data, instance=john)
self.assertTrue(deserialize.is_valid()) # self.assertTrue(deserialize.is_valid())
result = deserialize.object # result = deserialize.object
result.save() # result.save()
self.assertEqual(result.id, john.id) # self.assertEqual(result.id, john.id)
...@@ -109,7 +109,7 @@ class ThrottlingTests(TestCase): ...@@ -109,7 +109,7 @@ class ThrottlingTests(TestCase):
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
""" """
Ensure the response returns an X-Throttle field with status and next attributes Ensure the response returns an Retry-After field with status and next attributes
set properly. set properly.
""" """
request = self.factory.get('/') request = self.factory.get('/')
...@@ -117,10 +117,8 @@ class ThrottlingTests(TestCase): ...@@ -117,10 +117,8 @@ class ThrottlingTests(TestCase):
self.set_throttle_timer(view, timer) self.set_throttle_timer(view, timer)
response = view.as_view()(request) response = view.as_view()(request)
if expect is not None: if expect is not None:
self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
self.assertEqual(response['Retry-After'], expect) self.assertEqual(response['Retry-After'], expect)
else: else:
self.assertFalse('X-Throttle-Wait-Seconds' in response)
self.assertFalse('Retry-After' in response) self.assertFalse('Retry-After' in response)
def test_seconds_fields(self): def test_seconds_fields(self):
...@@ -173,13 +171,11 @@ class ThrottlingTests(TestCase): ...@@ -173,13 +171,11 @@ class ThrottlingTests(TestCase):
self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
response = MockView_NonTimeThrottling.as_view()(request) response = MockView_NonTimeThrottling.as_view()(request)
self.assertFalse('X-Throttle-Wait-Seconds' in response)
self.assertFalse('Retry-After' in response) self.assertFalse('Retry-After' in response)
self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
response = MockView_NonTimeThrottling.as_view()(request) response = MockView_NonTimeThrottling.as_view()(request)
self.assertFalse('X-Throttle-Wait-Seconds' in response)
self.assertFalse('Retry-After' in response) self.assertFalse('Retry-After' in response)
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.validators import MaxValueValidator from django.core.validators import MaxValueValidator
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers, status from rest_framework import generics, serializers, status
...@@ -22,23 +23,10 @@ class ValidationModelSerializer(serializers.ModelSerializer): ...@@ -22,23 +23,10 @@ class ValidationModelSerializer(serializers.ModelSerializer):
class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
model = ValidationModel queryset = ValidationModel.objects.all()
serializer_class = ValidationModelSerializer serializer_class = ValidationModelSerializer
class TestPreSaveValidationExclusions(TestCase):
def test_pre_save_validation_exclusions(self):
"""
Somewhat weird test case to ensure that we don't perform model
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', {}, format='json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Regression for #653 # Regression for #653
class ShouldValidateModel(models.Model): class ShouldValidateModel(models.Model):
...@@ -48,11 +36,10 @@ class ShouldValidateModel(models.Model): ...@@ -48,11 +36,10 @@ class ShouldValidateModel(models.Model):
class ShouldValidateModelSerializer(serializers.ModelSerializer): class ShouldValidateModelSerializer(serializers.ModelSerializer):
renamed = serializers.CharField(source='should_validate_field', required=False) renamed = serializers.CharField(source='should_validate_field', required=False)
def validate_renamed(self, attrs, source): def validate_renamed(self, value):
value = attrs[source]
if len(value) < 3: if len(value) < 3:
raise serializers.ValidationError('Minimum 3 characters.') raise serializers.ValidationError('Minimum 3 characters.')
return attrs return value
class Meta: class Meta:
model = ShouldValidateModel model = ShouldValidateModel
...@@ -117,7 +104,7 @@ class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): ...@@ -117,7 +104,7 @@ class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
model = ValidationMaxValueValidatorModel queryset = ValidationMaxValueValidatorModel.objects.all()
serializer_class = ValidationMaxValueValidatorModelSerializer serializer_class = ValidationMaxValueValidatorModelSerializer
...@@ -144,5 +131,44 @@ class TestMaxValueValidatorValidation(TestCase): ...@@ -144,5 +131,44 @@ class TestMaxValueValidatorValidation(TestCase):
request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
view = UpdateMaxValueValidationModel().as_view() view = UpdateMaxValueValidationModel().as_view()
response = view(request, pk=obj.pk).render() response = view(request, pk=obj.pk).render()
self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') self.assertEqual(response.content, b'{"number_value":["Ensure this value is less than or equal to 100."]}')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
class TestChoiceFieldChoicesValidate(TestCase):
CHOICES = [
(0, 'Small'),
(1, 'Medium'),
(2, 'Large'),
]
CHOICES_NESTED = [
('Category', (
(1, 'First'),
(2, 'Second'),
(3, 'Third'),
)),
(4, 'Fourth'),
]
def test_choices(self):
"""
Make sure a value for choices works as expected.
"""
f = serializers.ChoiceField(choices=self.CHOICES)
value = self.CHOICES[0][0]
try:
f.to_internal_value(value)
except ValidationError:
self.fail("Value %s does not validate" % str(value))
# def test_nested_choices(self):
# """
# Make sure a nested value for choices works as expected.
# """
# f = serializers.ChoiceField(choices=self.CHOICES_NESTED)
# value = self.CHOICES_NESTED[0][1][0][0]
# try:
# f.to_native(value)
# except ValidationError:
# self.fail("Value %s does not validate" % str(value))
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
class ExampleModel(models.Model):
email = models.EmailField(max_length=100)
password = models.CharField(max_length=100)
class WriteOnlyFieldTests(TestCase): class WriteOnlyFieldTests(TestCase):
def test_write_only_fields(self): def setUp(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
password = serializers.CharField(write_only=True) password = serializers.CharField(write_only=True)
def create(self, attrs):
return attrs
self.Serializer = ExampleSerializer
def write_only_fields_are_present_on_input(self):
data = { data = {
'email': 'foo@example.com', 'email': 'foo@example.com',
'password': '123' 'password': '123'
} }
serializer = ExampleSerializer(data=data) serializer = self.Serializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object, data) self.assertEquals(serializer.validated_data, data)
self.assertEquals(serializer.data, {'email': 'foo@example.com'})
def test_write_only_fields_meta(self):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = ExampleModel
fields = ('email', 'password')
write_only_fields = ('password',)
data = { def write_only_fields_are_not_present_on_output(self):
instance = {
'email': 'foo@example.com', 'email': 'foo@example.com',
'password': '123' 'password': '123'
} }
serializer = ExampleSerializer(data=data) serializer = self.Serializer(instance)
self.assertTrue(serializer.is_valid())
self.assertTrue(isinstance(serializer.object, ExampleModel))
self.assertEquals(serializer.object.email, data['email'])
self.assertEquals(serializer.object.password, data['password'])
self.assertEquals(serializer.data, {'email': 'foo@example.com'}) self.assertEquals(serializer.data, {'email': 'foo@example.com'})
from contextlib import contextmanager from contextlib import contextmanager
from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import NoReverseMatch
from django.utils import six from django.utils import six
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
...@@ -23,3 +25,54 @@ def temporary_setting(setting, value, module=None): ...@@ -23,3 +25,54 @@ def temporary_setting(setting, value, module=None):
if module is not None: if module is not None:
six.moves.reload_module(module) six.moves.reload_module(module)
class MockObject(object):
def __init__(self, **kwargs):
self._kwargs = kwargs
for key, val in kwargs.items():
setattr(self, key, val)
def __str__(self):
kwargs_str = ', '.join([
'%s=%s' % (key, value)
for key, value in sorted(self._kwargs.items())
])
return '<MockObject %s>' % kwargs_str
class MockQueryset(object):
def __init__(self, iterable):
self.items = iterable
def get(self, **lookup):
for item in self.items:
if all([
getattr(item, key, None) == value
for key, value in lookup.items()
]):
return item
raise ObjectDoesNotExist()
class BadType(object):
"""
When used as a lookup with a `MockQueryset`, these objects
will raise a `TypeError`, as occurs in Django when making
queryset lookups with an incorrect type for the lookup value.
"""
def __eq__(self):
raise TypeError()
def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None):
args = args or []
kwargs = kwargs or {}
value = (args + list(kwargs.values()) + ['-'])[0]
prefix = 'http://example.org' if request else ''
suffix = ('.' + format) if (format is not None) else ''
return '%s/%s/%s%s/' % (prefix, view_name, value, suffix)
def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None):
raise NoReverseMatch()
from rest_framework import generics
from .models import NullableForeignKeySource
from .serializers import NullableFKSourceSerializer
class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
model = NullableForeignKeySource
model_serializer_class = NullableFKSourceSerializer
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment