Commit e40000c8 by Tom Christie

Merge pull request #408 from markotibold/file_and_image_fields

Added a FileField and an ImageField
parents 31f01bd6 f801e5d3
...@@ -165,6 +165,33 @@ A floating point representation. ...@@ -165,6 +165,33 @@ A floating point representation.
Corresponds to `django.db.models.fields.FloatField`. Corresponds to `django.db.models.fields.FloatField`.
## FileField
A file representation. Performs Django's standard FileField validation.
Corresponds to `django.forms.fields.FileField`.
**Signature:** `FileField(max_length=None, allow_empty_file=False)`
- `max_length` designates the maximum length for the file name.
- `allow_empty_file` designates if empty files are allowed.
## ImageField
An image representation.
Corresponds to `django.forms.fields.ImageField`.
Requires the `PIL` package.
Signature and validation is the same as with `FileField`.
---
**Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since eg json doesn't support file uploads.
Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files.
--- ---
# Relational Fields # Relational Fields
...@@ -286,3 +313,4 @@ This field is always read-only. ...@@ -286,3 +313,4 @@ This field is always read-only.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`. * `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
[cite]: http://www.python.org/dev/peps/pep-0020/ [cite]: http://www.python.org/dev/peps/pep-0020/
[FILE_UPLOAD_HANDLERS]: https://docs.djangoproject.com/en/dev/ref/settings/#std:setting-FILE_UPLOAD_HANDLERS
...@@ -3,6 +3,8 @@ import datetime ...@@ -3,6 +3,8 @@ import datetime
import inspect import inspect
import warnings import warnings
from io import BytesIO
from django.core import validators from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix from django.core.urlresolvers import resolve, get_script_prefix
...@@ -31,6 +33,7 @@ class Field(object): ...@@ -31,6 +33,7 @@ class Field(object):
creation_counter = 0 creation_counter = 0
empty = '' empty = ''
type_name = None type_name = None
_use_files = None
def __init__(self, source=None): def __init__(self, source=None):
self.parent = None self.parent = None
...@@ -51,7 +54,7 @@ class Field(object): ...@@ -51,7 +54,7 @@ class Field(object):
self.root = parent.root or parent self.root = parent.root or parent
self.context = self.root.context self.context = self.root.context
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value. with the field and it's deserialized value.
...@@ -166,7 +169,7 @@ class WritableField(Field): ...@@ -166,7 +169,7 @@ class WritableField(Field):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value. with the field and it's deserialized value.
...@@ -175,6 +178,9 @@ class WritableField(Field): ...@@ -175,6 +178,9 @@ class WritableField(Field):
return return
try: try:
if self._use_files:
native = files[field_name]
else:
native = data[field_name] native = data[field_name]
except KeyError: except KeyError:
if self.default is not None: if self.default is not None:
...@@ -323,7 +329,7 @@ class RelatedField(WritableField): ...@@ -323,7 +329,7 @@ class RelatedField(WritableField):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
return self.to_native(value) return self.to_native(value)
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
if self.read_only: if self.read_only:
return return
...@@ -341,7 +347,7 @@ class ManyRelatedMixin(object): ...@@ -341,7 +347,7 @@ class ManyRelatedMixin(object):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()] return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
if self.read_only: if self.read_only:
return return
...@@ -904,3 +910,95 @@ class FloatField(WritableField): ...@@ -904,3 +910,95 @@ class FloatField(WritableField):
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value msg = self.error_messages['invalid'] % value
raise ValidationError(msg) raise ValidationError(msg)
class FileField(WritableField):
_use_files = True
type_name = 'FileField'
widget = widgets.FileInput
default_error_messages = {
'invalid': _("No file was submitted. Check the encoding type on the form."),
'missing': _("No file was submitted."),
'empty': _("The submitted file is empty."),
'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
}
def __init__(self, *args, **kwargs):
self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False)
super(FileField, self).__init__(*args, **kwargs)
def from_native(self, data):
if data in validators.EMPTY_VALUES:
return None
# UploadedFile objects should have name and size attributes.
try:
file_name = data.name
file_size = data.size
except AttributeError:
raise ValidationError(self.error_messages['invalid'])
if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values)
if not file_name:
raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size:
raise ValidationError(self.error_messages['empty'])
return data
def to_native(self, value):
return value.name
class ImageField(FileField):
_use_files = True
default_error_messages = {
'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
}
def from_native(self, data):
"""
Checks that the file-upload field data contains a valid image (GIF, JPG,
PNG, possibly others -- whatever the Python Imaging Library supports).
"""
f = super(ImageField, self).from_native(data)
if f is None:
return None
# Try to import PIL in either of the two ways it can end up installed.
try:
from PIL import Image
except ImportError:
import Image
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
if hasattr(data, 'temporary_file_path'):
file = data.temporary_file_path()
else:
if hasattr(data, 'read'):
file = BytesIO(data.read())
else:
file = BytesIO(data['content'])
try:
# load() could spot a truncated JPEG, but it loads the entire
# image in memory, which is a DoS vector. See #3848 and #18520.
# verify() must be called immediately after the constructor.
Image.open(file).verify()
except ImportError:
# Under PyPy, it is possible to import PIL. However, the underlying
# _imaging C module isn't available, so an ImportError will be
# raised. Catch and re-raise.
raise
except Exception: # Python Imaging Library doesn't recognize it as an image
raise ValidationError(self.error_messages['invalid_image'])
if hasattr(f, 'seek') and callable(f.seek):
f.seek(0)
return f
...@@ -47,11 +47,10 @@ class GenericAPIView(views.APIView): ...@@ -47,11 +47,10 @@ class GenericAPIView(views.APIView):
return serializer_class return serializer_class
def get_serializer(self, instance=None, data=None, files=None): def get_serializer(self, instance=None, data=None, files=None):
# TODO: add support for files
# TODO: add support for seperate serializer/deserializer # TODO: add support for seperate serializer/deserializer
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, context=context) return serializer_class(instance, data=data, files=files, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
......
...@@ -15,7 +15,7 @@ class CreateModelMixin(object): ...@@ -15,7 +15,7 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`. Should be mixed in with any `BaseView`.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA) serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save()
...@@ -89,7 +89,7 @@ class UpdateModelMixin(object): ...@@ -89,7 +89,7 @@ class UpdateModelMixin(object):
self.object = None self.object = None
created = True created = True
serializer = self.get_serializer(self.object, data=request.DATA) serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
......
...@@ -320,7 +320,9 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -320,7 +320,9 @@ class BrowsableAPIRenderer(BaseRenderer):
serializers.SlugRelatedField: forms.ChoiceField, serializers.SlugRelatedField: forms.ChoiceField,
serializers.ManySlugRelatedField: forms.MultipleChoiceField, serializers.ManySlugRelatedField: forms.MultipleChoiceField,
serializers.HyperlinkedRelatedField: forms.ChoiceField, serializers.HyperlinkedRelatedField: forms.ChoiceField,
serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField,
serializers.FileField: forms.FileField,
serializers.ImageField: forms.ImageField,
} }
fields = {} fields = {}
......
...@@ -91,7 +91,7 @@ class BaseSerializer(Field): ...@@ -91,7 +91,7 @@ class BaseSerializer(Field):
_options_class = SerializerOptions _options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
def __init__(self, instance=None, data=None, context=None, **kwargs): def __init__(self, instance=None, data=None, files=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
self.fields = copy.deepcopy(self.base_fields) self.fields = copy.deepcopy(self.base_fields)
...@@ -101,9 +101,11 @@ class BaseSerializer(Field): ...@@ -101,9 +101,11 @@ class BaseSerializer(Field):
self.context = context or {} self.context = context or {}
self.init_data = data self.init_data = data
self.init_files = files
self.object = instance self.object = instance
self._data = None self._data = None
self._files = None
self._errors = None self._errors = None
##### #####
...@@ -187,7 +189,7 @@ class BaseSerializer(Field): ...@@ -187,7 +189,7 @@ class BaseSerializer(Field):
ret.fields[key] = field ret.fields[key] = field
return ret return ret
def restore_fields(self, data): def restore_fields(self, data, files):
""" """
Core of deserialization, together with `restore_object`. Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields. Converts a dictionary of data into a dictionary of deserialized fields.
...@@ -196,7 +198,7 @@ class BaseSerializer(Field): ...@@ -196,7 +198,7 @@ class BaseSerializer(Field):
reverted_data = {} reverted_data = {}
for field_name, field in fields.items(): for field_name, field in fields.items():
try: try:
field.field_from_native(data, field_name, reverted_data) field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err: except ValidationError as err:
self._errors[field_name] = list(err.messages) self._errors[field_name] = list(err.messages)
...@@ -250,7 +252,7 @@ class BaseSerializer(Field): ...@@ -250,7 +252,7 @@ class BaseSerializer(Field):
return [self.convert_object(item) for item in obj] return [self.convert_object(item) for item in obj]
return self.convert_object(obj) return self.convert_object(obj)
def from_native(self, data): def from_native(self, data, files):
""" """
Deserialize primitives -> objects. Deserialize primitives -> objects.
""" """
...@@ -259,8 +261,8 @@ class BaseSerializer(Field): ...@@ -259,8 +261,8 @@ class BaseSerializer(Field):
return (self.from_native(item) for item in data) return (self.from_native(item) for item in data)
self._errors = {} self._errors = {}
if data is not None: if data is not None or files is not None:
attrs = self.restore_fields(data) attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs) attrs = self.perform_validation(attrs)
else: else:
self._errors['non_field_errors'] = ['No input provided'] self._errors['non_field_errors'] = ['No input provided']
...@@ -288,7 +290,7 @@ class BaseSerializer(Field): ...@@ -288,7 +290,7 @@ class BaseSerializer(Field):
setting self.object if no errors occurred. setting self.object if no errors occurred.
""" """
if self._errors is None: if self._errors is None:
obj = self.from_native(self.init_data) obj = self.from_native(self.init_data, self.init_files)
if not self._errors: if not self._errors:
self.object = obj self.object = obj
return self._errors return self._errors
...@@ -440,6 +442,8 @@ class ModelSerializer(Serializer): ...@@ -440,6 +442,8 @@ class ModelSerializer(Serializer):
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.FileField: FileField,
models.ImageField: ImageField,
} }
try: try:
return field_mapping[model_field.__class__](**kwargs) return field_mapping[model_field.__class__](**kwargs)
......
# from django.test import TestCase import StringIO
# from django import forms import datetime
# from django.test.client import RequestFactory from django.test import TestCase
# from rest_framework.views import View
# from rest_framework.response import Response
# import StringIO from rest_framework import serializers
# class UploadFilesTests(TestCase): class UploadedFile(object):
# """Check uploading of files""" def __init__(self, file, created=None):
# def setUp(self): self.file = file
# self.factory = RequestFactory() self.created = created or datetime.datetime.now()
# def test_upload_file(self):
# class FileForm(forms.Form): class UploadedFileSerializer(serializers.Serializer):
# file = forms.FileField() file = serializers.FileField()
created = serializers.DateTimeField()
# class MockView(View): def restore_object(self, attrs, instance=None):
# permissions = () if instance:
# form = FileForm instance.file = attrs['file']
instance.created = attrs['created']
return instance
return UploadedFile(**attrs)
# def post(self, request, *args, **kwargs):
# return Response({'FILE_NAME': self.CONTENT['file'].name,
# 'FILE_CONTENT': self.CONTENT['file'].read()})
# file = StringIO.StringIO('stuff') class FileSerializerTests(TestCase):
# file.name = 'stuff.txt'
# request = self.factory.post('/', {'file': file}) def test_create(self):
# view = MockView.as_view() now = datetime.datetime.now()
# response = view(request) file = StringIO.StringIO('stuff')
# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}) file.name = 'stuff.txt'
file.size = file.len
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object.created, uploaded_file.created)
self.assertEquals(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment