Commit f2852811 by Tom Christie

Getting tests passing

parent ec096a1c
...@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer): ...@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
else: else:
msg = _('Unable to login with provided credentials.') msg = _('Unable to login with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
else: else:
msg = _('Must include "username" and "password"') msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
...@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView): ...@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user']) user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key}) return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
from rest_framework.utils import html from rest_framework.utils import html
import inspect
class empty: class empty:
...@@ -11,6 +12,22 @@ class empty: ...@@ -11,6 +12,22 @@ class empty:
pass pass
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
function = inspect.isfunction(obj)
method = inspect.ismethod(obj)
if not (function or method):
return False
args, _, _, defaults = inspect.getargspec(obj)
len_args = len(args) if function else len(args) - 1
len_defaults = len(defaults) if defaults else 0
return len_args <= len_defaults
def get_attribute(instance, attrs): def get_attribute(instance, attrs):
""" """
Similar to Python's built in `getattr(instance, attr)`, Similar to Python's built in `getattr(instance, attr)`,
...@@ -98,6 +115,7 @@ class Field(object): ...@@ -98,6 +115,7 @@ class Field(object):
self.field_name = field_name self.field_name = field_name
self.parent = parent self.parent = parent
self.root = root self.root = root
self.context = parent.context
# `self.label` should deafult to being based on the field name. # `self.label` should deafult to being based on the field name.
if self.label is None: if self.label is None:
...@@ -297,25 +315,55 @@ class IntegerField(Field): ...@@ -297,25 +315,55 @@ class IntegerField(Field):
self.fail('invalid_integer') self.fail('invalid_integer')
return data return data
def to_primative(self, value):
if value is None:
return None
return int(value)
class EmailField(CharField): class EmailField(CharField):
pass # TODO pass # TODO
class URLField(CharField):
pass # TODO
class RegexField(CharField): class RegexField(CharField):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.regex = kwargs.pop('regex') self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs) super(CharField, self).__init__(**kwargs)
class DateField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateField, self).__init__(**kwargs)
class TimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(TimeField, self).__init__(**kwargs)
class DateTimeField(CharField): class DateTimeField(CharField):
pass # TODO def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateTimeField, self).__init__(**kwargs)
class FileField(Field): class FileField(Field):
pass # TODO pass # TODO
class ReadOnlyField(Field):
def to_primative(self, value):
if is_simple_callable(value):
return value()
return value
class MethodField(Field): class MethodField(Field):
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs['source'] = '*' kwargs['source'] = '*'
......
...@@ -13,23 +13,6 @@ from rest_framework.request import clone_request ...@@ -13,23 +13,6 @@ from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
def _get_validation_exclusions(obj, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
if lookup_field == 'pk':
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
lookup_field = pk_field.name
return [field.name for field in obj._meta.fields if field.name != lookup_field]
class CreateModelMixin(object): class CreateModelMixin(object):
""" """
Create a model instance. Create a model instance.
...@@ -92,15 +75,14 @@ class UpdateModelMixin(object): ...@@ -92,15 +75,14 @@ class UpdateModelMixin(object):
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
if self.object is None:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg] lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value} extras = {self.lookup_field: lookup_value}
if self.object is None:
self.object = serializer.save(extras=extras) self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
self.object = serializer.save(extras=extras) self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
...@@ -122,21 +104,6 @@ class UpdateModelMixin(object): ...@@ -122,21 +104,6 @@ class UpdateModelMixin(object):
# return a 404 response. # return a 404 response.
raise raise
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
setattr(obj, self.lookup_field, lookup_value)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
exclude = _get_validation_exclusions(obj, self.lookup_field)
obj.full_clean(exclude)
class DestroyModelMixin(object): class DestroyModelMixin(object):
""" """
......
...@@ -13,7 +13,7 @@ class NextPageField(serializers.Field): ...@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_primative(self, value):
if not value.has_next(): if not value.has_next():
return None return None
page = value.next_page_number() page = value.next_page_number()
...@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field): ...@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_primative(self, value):
if not value.has_previous(): if not value.has_previous():
return None return None
page = value.previous_page_number() page = value.previous_page_number()
...@@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field): ...@@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source) super(DefaultObjectSerializer, self).__init__(source=source)
# class PaginationSerializerOptions(serializers.SerializerOptions):
# """
# An object that stores the options that may be provided to a
# pagination serializer by using the inner `Meta` class.
# Accessible on the instance as `serializer.opts`.
# """
# def __init__(self, meta):
# super(PaginationSerializerOptions, self).__init__(meta)
# self.object_serializer_class = getattr(meta, 'object_serializer_class',
# DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer): class BasePaginationSerializer(serializers.Serializer):
""" """
A base class for pagination serializers to inherit from, A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy. to make implementing custom serializers more easy.
""" """
# _options_class = PaginationSerializerOptions
results_field = 'results' results_field = 'results'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer): ...@@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer):
""" """
super(BasePaginationSerializer, self).__init__(*args, **kwargs) super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field results_field = self.results_field
object_serializer = self.opts.object_serializer_class try:
object_serializer = self.Meta.object_serializer_class
if 'context' in kwargs: except AttributeError:
context_kwarg = {'context': kwargs['context']} object_serializer = DefaultObjectSerializer
else:
context_kwarg = {} self.fields[results_field] = serializers.ListSerializer(
child=object_serializer(),
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) source='object_list'
)
self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding
class PaginationSerializer(BasePaginationSerializer): class PaginationSerializer(BasePaginationSerializer):
......
...@@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField):
try: try:
http_prefix = value.startswith(('http:', 'https:')) http_prefix = value.startswith(('http:', 'https:'))
except AttributeError: except AttributeError:
self.fail('incorrect_type', type(value).__name__) self.fail('incorrect_type', data_type=type(value).__name__)
if http_prefix: if http_prefix:
# If needed convert absolute URLs to relative path # If needed convert absolute URLs to relative path
......
...@@ -142,7 +142,7 @@ class Serializer(BaseSerializer): ...@@ -142,7 +142,7 @@ class Serializer(BaseSerializer):
return super(Serializer, cls).__new__(cls) return super(Serializer, cls).__new__(cls)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs.pop('context', None) self.context = kwargs.pop('context', {})
kwargs.pop('partial', None) kwargs.pop('partial', None)
kwargs.pop('many', False) kwargs.pop('many', False)
...@@ -202,7 +202,7 @@ class Serializer(BaseSerializer): ...@@ -202,7 +202,7 @@ class Serializer(BaseSerializer):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
return ret return self.validate(ret)
def to_primative(self, instance): def to_primative(self, instance):
""" """
...@@ -217,6 +217,9 @@ class Serializer(BaseSerializer): ...@@ -217,6 +217,9 @@ class Serializer(BaseSerializer):
return ret return ret
def validate(self, attrs):
return attrs
def __iter__(self): def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {} errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values(): for field in self.fields.values():
...@@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer): ...@@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
self.context = kwargs.pop('context', {})
kwargs.pop('context', None)
kwargs.pop('partial', None) kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs) super(ListSerializer, self).__init__(*args, **kwargs)
...@@ -316,19 +318,19 @@ class ModelSerializer(Serializer): ...@@ -316,19 +318,19 @@ class ModelSerializer(Serializer):
models.PositiveIntegerField: IntegerField, models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField, models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField,
# models.DateTimeField: DateTimeField, models.DateTimeField: DateTimeField,
# models.DateField: DateField, models.DateField: DateField,
# models.TimeField: TimeField, models.TimeField: TimeField,
# models.DecimalField: DecimalField, # models.DecimalField: DecimalField,
# models.EmailField: EmailField, models.EmailField: EmailField,
models.CharField: CharField, models.CharField: CharField,
# models.URLField: URLField, models.URLField: URLField,
# models.SlugField: SlugField, # models.SlugField: SlugField,
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField, models.NullBooleanField: BooleanField,
# models.FileField: FileField, models.FileField: FileField,
# models.ImageField: ImageField, # models.ImageField: ImageField,
} }
...@@ -338,6 +340,15 @@ class ModelSerializer(Serializer): ...@@ -338,6 +340,15 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs) super(ModelSerializer, self).__init__(*args, **kwargs)
def create(self):
ModelClass = self.opts.model
return ModelClass.objects.create(**self.validated_data)
def update(self, obj):
for attr, value in self.validated_data.items():
setattr(obj, attr, value)
obj.save()
def get_fields(self): def get_fields(self):
# Get the explicitly declared fields. # Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields) fields = copy.deepcopy(self.base_fields)
...@@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): ...@@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail' _default_view_name = '%(model_name)s-detail'
# _hyperlink_field_class = HyperlinkedRelatedField _hyperlink_field_class = HyperlinkedRelatedField
# _hyperlink_identify_field_class = HyperlinkedIdentityField _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self): def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields() fields = super(HyperlinkedModelSerializer, self).get_default_fields()
...@@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model) self.opts.view_name = self._get_default_view_name(self.opts.model)
# if self.opts.url_field_name not in fields: if self.opts.url_field_name not in fields:
# url_field = self._hyperlink_identify_field_class( url_field = self._hyperlink_identify_field_class(
# view_name=self.opts.view_name, view_name=self.opts.view_name,
# lookup_field=self.opts.lookup_field lookup_field=self.opts.lookup_field
# ) )
# ret = self._dict_class() ret = fields.__class__()
# ret[self.opts.url_field_name] = url_field ret[self.opts.url_field_name] = url_field
# ret.update(fields) ret.update(fields)
# fields = ret fields = ret
return fields return fields
......
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.test import TestCase # from django.test import TestCase
from django.utils import six # from django.utils import six
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import BytesIO # from rest_framework.compat import BytesIO
import datetime # import datetime
class UploadedFile(object): # class UploadedFile(object):
def __init__(self, file=None, created=None): # def __init__(self, file=None, created=None):
self.file = file # self.file = file
self.created = created or datetime.datetime.now() # self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer): # class UploadedFileSerializer(serializers.Serializer):
file = serializers.FileField(required=False) # file = serializers.FileField(required=False)
created = serializers.DateTimeField() # created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.file = attrs['file'] # instance.file = attrs['file']
instance.created = attrs['created'] # instance.created = attrs['created']
return instance # return instance
return UploadedFile(**attrs) # return UploadedFile(**attrs)
class FileSerializerTests(TestCase): # class FileSerializerTests(TestCase):
def test_create(self): # def test_create(self):
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) # serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertEqual(serializer.object.file, uploaded_file.file) # self.assertEqual(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file) # self.assertFalse(serializer.object is uploaded_file)
def test_creation_failure(self): # def test_creation_failure(self):
""" # """
Passing files=None should result in an ValidationError # Passing files=None should result in an ValidationError
Regression test for: # Regression test for:
https://github.com/tomchristie/django-rest-framework/issues/542 # https://github.com/tomchristie/django-rest-framework/issues/542
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now}) # serializer = UploadedFileSerializer(data={'created': now})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, now) # self.assertEqual(serializer.object.created, now)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_remove_with_empty_string(self): # def test_remove_with_empty_string(self):
""" # """
Passing empty string as data should cause file to be removed # Passing empty string as data should cause file to be removed
Test for: # Test for:
https://github.com/tomchristie/django-rest-framework/issues/937 # https://github.com/tomchristie/django-rest-framework/issues/937
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) # serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_validation_error_with_non_file(self): # def test_validation_error_with_non_file(self):
""" # """
Passing non-files should raise a validation error. # Passing non-files should raise a validation error.
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
errmsg = 'No file was submitted. Check the encoding type on the form.' # errmsg = 'No file was submitted. Check the encoding type on the form.'
serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) # serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'file': [errmsg]}) # self.assertEqual(serializer.errors, {'file': [errmsg]})
def test_validation_with_no_data(self): # def test_validation_with_no_data(self):
""" # """
Validation should still function when no data dictionary is provided. # Validation should still function when no data dictionary is provided.
""" # """
uploaded_file = BytesIO(six.b('stuff')) # uploaded_file = BytesIO(six.b('stuff'))
uploaded_file.name = 'stuff.txt' # uploaded_file.name = 'stuff.txt'
uploaded_file.size = len(uploaded_file.getvalue()) # uploaded_file.size = len(uploaded_file.getvalue())
serializer = UploadedFileSerializer(files={'file': uploaded_file}) # serializer = UploadedFileSerializer(files={'file': uploaded_file})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.contrib.contenttypes.models import ContentType # from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey # from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import python_2_unicode_compatible # from rest_framework.compat import python_2_unicode_compatible
@python_2_unicode_compatible # @python_2_unicode_compatible
class Tag(models.Model): # class Tag(models.Model):
""" # """
Tags have a descriptive slug, and are attached to an arbitrary object. # Tags have a descriptive slug, and are attached to an arbitrary object.
""" # """
tag = models.SlugField() # tag = models.SlugField()
content_type = models.ForeignKey(ContentType) # content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField() # object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') # tagged_item = GenericForeignKey('content_type', 'object_id')
def __str__(self): # def __str__(self):
return self.tag # return self.tag
@python_2_unicode_compatible # @python_2_unicode_compatible
class Bookmark(models.Model): # class Bookmark(models.Model):
""" # """
A URL bookmark that may have multiple tags attached. # A URL bookmark that may have multiple tags attached.
""" # """
url = models.URLField() # url = models.URLField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Bookmark: %s' % self.url # return 'Bookmark: %s' % self.url
@python_2_unicode_compatible # @python_2_unicode_compatible
class Note(models.Model): # class Note(models.Model):
""" # """
A textual note that may have multiple tags attached. # A textual note that may have multiple tags attached.
""" # """
text = models.TextField() # text = models.TextField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Note: %s' % self.text # return 'Note: %s' % self.text
class TestGenericRelations(TestCase): # class TestGenericRelations(TestCase):
def setUp(self): # def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') # self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django') # Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python') # Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk') # self.note = Note.objects.create(text='Remember the milk')
Tag.objects.create(tagged_item=self.note, tag='reminder') # Tag.objects.create(tagged_item=self.note, tag='reminder')
def test_generic_relation(self): # def test_generic_relation(self):
""" # """
Test a relationship that spans a GenericRelation field. # Test a relationship that spans a GenericRelation field.
IE. A reverse generic relationship. # IE. A reverse generic relationship.
""" # """
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.RelatedField(many=True) # tags = serializers.RelatedField(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
serializer = BookmarkSerializer(self.bookmark) # serializer = BookmarkSerializer(self.bookmark)
expected = { # expected = {
'tags': ['django', 'python'], # 'tags': ['django', 'python'],
'url': 'https://www.djangoproject.com/' # 'url': 'https://www.djangoproject.com/'
} # }
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_generic_nested_relation(self): # def test_generic_nested_relation(self):
""" # """
Test saving a GenericRelation field via a nested serializer. # Test saving a GenericRelation field via a nested serializer.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = TagSerializer(many=True) # tags = TagSerializer(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
data = { # data = {
'url': 'https://docs.djangoproject.com/', # 'url': 'https://docs.djangoproject.com/',
'tags': [ # 'tags': [
{'tag': 'contenttypes'}, # {'tag': 'contenttypes'},
{'tag': 'genericrelations'}, # {'tag': 'genericrelations'},
] # ]
} # }
serializer = BookmarkSerializer(data=data) # serializer = BookmarkSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.object.tags.count(), 2) # self.assertEqual(serializer.object.tags.count(), 2)
def test_generic_fk(self): # def test_generic_fk(self):
""" # """
Test a relationship that spans a GenericForeignKey field. # Test a relationship that spans a GenericForeignKey field.
IE. A forward generic relationship. # IE. A forward generic relationship.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
tagged_item = serializers.RelatedField() # tagged_item = serializers.RelatedField()
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('id', 'content_type', 'object_id') # exclude = ('id', 'content_type', 'object_id')
serializer = TagSerializer(Tag.objects.all(), many=True) # serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [ # expected = [
{ # {
'tag': 'django', # 'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'python', # 'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'reminder', # 'tag': 'reminder',
'tagged_item': 'Note: Remember the milk' # 'tagged_item': 'Note: Remember the milk'
} # }
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_restore_object_generic_fk(self): # def test_restore_object_generic_fk(self):
""" # """
Ensure an object with a generic foreign key can be restored. # Ensure an object with a generic foreign key can be restored.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
serializer = TagSerializer() # serializer = TagSerializer()
bookmark = Bookmark(url='http://example.com') # bookmark = Bookmark(url='http://example.com')
attrs = {'tagged_item': bookmark, 'tag': 'example'} # attrs = {'tagged_item': bookmark, 'tag': 'example'}
tag = serializer.restore_object(attrs) # tag = serializer.restore_object(attrs)
self.assertEqual(tag.tagged_item, bookmark) # self.assertEqual(tag.tagged_item, bookmark)
from django.core.urlresolvers import reverse # from django.core.urlresolvers import reverse
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from rest_framework import serializers, generics # from rest_framework import serializers, generics
from rest_framework.test import APITestCase # from rest_framework.test import APITestCase
from tests.models import NullableForeignKeySource # from tests.models import NullableForeignKeySource
class NullableFKSourceSerializer(serializers.ModelSerializer): # class NullableFKSourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): # class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer_class = NullableFKSourceSerializer # serializer_class = NullableFKSourceSerializer
urlpatterns = patterns( # urlpatterns = patterns(
'', # '',
url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), # url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
) # )
class NullableForeignKeyTests(APITestCase): # class NullableForeignKeyTests(APITestCase):
""" # """
DRF should be able to handle nullable foreign keys when a test # DRF should be able to handle nullable foreign keys when a test
Client POST/PUT request is made with its own serialized object. # Client POST/PUT request is made with its own serialized object.
""" # """
urls = 'tests.test_nullable_fields' # urls = 'tests.test_nullable_fields'
def test_updating_object_with_null_fk(self): # def test_updating_object_with_null_fk(self):
obj = NullableForeignKeySource(name='example', target=None) # obj = NullableForeignKeySource(name='example', target=None)
obj.save() # obj.save()
serialized_data = NullableFKSourceSerializer(obj).data # serialized_data = NullableFKSourceSerializer(obj).data
response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) # response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
self.assertEqual(response.data, serialized_data) # self.assertEqual(response.data, serialized_data)
...@@ -391,10 +391,10 @@ class CustomField(serializers.Field): ...@@ -391,10 +391,10 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer): class BasicModelSerializer(serializers.Serializer):
text = CustomField() text = CustomField()
def __init__(self, *args, **kwargs): def to_native(self, value):
super(BasicModelSerializer, self).__init__(*args, **kwargs)
if 'view' not in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into serializer init") raise RuntimeError("context isn't getting passed into serializer")
return super(BasicSerializer, self).to_native(value)
class TestContextPassedToCustomField(TestCase): class TestContextPassedToCustomField(TestCase):
...@@ -423,7 +423,7 @@ class LinksSerializer(serializers.Serializer): ...@@ -423,7 +423,7 @@ class LinksSerializer(serializers.Serializer):
class CustomPaginationSerializer(pagination.BasePaginationSerializer): class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source links = LinksSerializer(source='*') # Takes the page object as the source
total_results = serializers.Field(source='paginator.count') total_results = serializers.ReadOnlyField(source='paginator.count')
results_field = 'objects' results_field = 'objects'
......
...@@ -108,59 +108,59 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -108,59 +108,59 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk='2') response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self): # def test_options_permitted(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.permitted_credentials # HTTP_AUTHORIZATION=self.permitted_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST']) # self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.permitted_credentials # HTTP_AUTHORIZATION=self.permitted_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self): # def test_options_disallowed(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.disallowed_credentials # HTTP_AUTHORIZATION=self.disallowed_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.disallowed_credentials # HTTP_AUTHORIZATION=self.disallowed_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
def test_options_updateonly(self): # def test_options_updateonly(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.updateonly_credentials # HTTP_AUTHORIZATION=self.updateonly_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.updateonly_credentials # HTTP_AUTHORIZATION=self.updateonly_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
class BasicPermModel(models.Model): class BasicPermModel(models.Model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment