Commit 5b7e4af0 by Tom Christie

get_base_field() refactor

parent c0155fd9
...@@ -80,10 +80,6 @@ def set_value(dictionary, keys, value): ...@@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value dictionary[keys[-1]] = value
def field_name_to_label(field_name):
return field_name.replace('_', ' ').capitalize()
class SkipField(Exception): class SkipField(Exception):
pass pass
...@@ -162,7 +158,7 @@ class Field(object): ...@@ -162,7 +158,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name. # `self.label` should deafult to being based on the field name.
if self.label is None: if self.label is None:
self.label = field_name_to_label(self.field_name) self.label = field_name.replace('_', ' ').capitalize()
# self.source should default to being the same as the field name. # self.source should default to being the same as the field name.
if self.source is None: if self.source is None:
......
...@@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField): ...@@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
} }
def __init__(self, view_name, **kwargs): def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
self.view_name = view_name self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
...@@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField): ...@@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
URL of relationships to other objects. URL of relationships to other objects.
""" """
def __init__(self, view_name, **kwargs): def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True kwargs['read_only'] = True
kwargs['source'] = '*' kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
...@@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField): ...@@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
'invalid': _('Invalid value.'), 'invalid': _('Invalid value.'),
} }
def __init__(self, slug_field, **kwargs): def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs) super(SlugRelatedField, self).__init__(**kwargs)
......
"""
Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivelent serializer fields.
"""
from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
import inspect
def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
for cls in inspect.getmro(instance.__class__):
if cls in mapping:
return mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
return capfirst(model_field.verbose_name) != default_label
def get_detail_view_name(model):
"""
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.flatchoices:
# If this model field contains choices, then return now,
# any further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
# Also note that Django<1.6 uses `default=False` for
# models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
return kwargs
def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, has_through_model = relation_info
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
if relation_info.to_many:
kwargs['many'] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}
""" """
Helper functions for returning the field information that is associated Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse with a model class. This includes returning all the forward and reverse
relationships and their associated metadata. relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import namedtuple from collections import namedtuple
from django.db import models from django.db import models
...@@ -9,8 +11,22 @@ from django.utils import six ...@@ -9,8 +11,22 @@ from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
import inspect import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related',
'to_many',
'has_through_model'
])
def _resolve_model(obj): def _resolve_model(obj):
...@@ -55,7 +71,7 @@ def get_field_info(model): ...@@ -55,7 +71,7 @@ def get_field_info(model):
forward_relations = SortedDict() forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]: for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
field=field, model_field=field,
related=_resolve_model(field.rel.to), related=_resolve_model(field.rel.to),
to_many=False, to_many=False,
has_through_model=False has_through_model=False
...@@ -64,7 +80,7 @@ def get_field_info(model): ...@@ -64,7 +80,7 @@ def get_field_info(model):
# Deal with forward many-to-many relationships. # Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]: for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
field=field, model_field=field,
related=_resolve_model(field.rel.to), related=_resolve_model(field.rel.to),
to_many=True, to_many=True,
has_through_model=( has_through_model=(
...@@ -77,7 +93,7 @@ def get_field_info(model): ...@@ -77,7 +93,7 @@ def get_field_info(model):
for relation in opts.get_all_related_objects(): for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
field=None, model_field=None,
related=relation.model, related=relation.model,
to_many=relation.field.rel.multiple, to_many=relation.field.rel.multiple,
has_through_model=False has_through_model=False
...@@ -87,7 +103,7 @@ def get_field_info(model): ...@@ -87,7 +103,7 @@ def get_field_info(model):
for relation in opts.get_all_related_many_to_many_objects(): for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
field=None, model_field=None,
related=relation.model, related=relation.model,
to_many=True, to_many=True,
has_through_model=( has_through_model=(
...@@ -96,4 +112,18 @@ def get_field_info(model): ...@@ -96,4 +112,18 @@ def get_field_info(model):
) )
) )
return FieldInfo(pk, fields, forward_relations, reverse_relations) # Shortcut that merges both regular fields and the pk,
# for simplifying regular field lookup.
fields_and_pk = SortedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
# Shortcut that merges both forward and reverse relationships
relations = SortedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
def foobar(): def foobar():
...@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel): ...@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True, target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source') related_name='nullable_source')
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
...@@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase): ...@@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
foreign_key = NestedModelSerializer(read_only=True): foreign_key = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True): one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True): many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True): through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
...@@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase): ...@@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail') url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedModelSerializer(read_only=True): foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True): one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True): many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True): through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
......
...@@ -2,11 +2,12 @@ from __future__ import unicode_literals ...@@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.conf.urls import patterns, url, include from django.conf.urls import patterns, url, include
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from tests.models import BasicModel, BasicModelSerializer from tests.models import BasicModel
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework import generics from rest_framework import generics
from rest_framework import routers from rest_framework import routers
from rest_framework import serializers
from rest_framework import status from rest_framework import status
from rest_framework.renderers import ( from rest_framework.renderers import (
BaseRenderer, BaseRenderer,
...@@ -17,6 +18,12 @@ from rest_framework import viewsets ...@@ -17,6 +18,12 @@ from rest_framework import viewsets
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class MockPickleRenderer(BaseRenderer): class MockPickleRenderer(BaseRenderer):
media_type = 'application/pickle' media_type = 'application/pickle'
......
...@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase): ...@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
def setUp(self): def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer): class NoteSerializer(serializers.HyperlinkedModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
class Meta: class Meta:
model = RouterTestModel model = RouterTestModel
lookup_field = 'uuid'
fields = ('url', 'uuid', 'text') fields = ('url', 'uuid', 'text')
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
...@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase): ...@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
serializer_class = NoteSerializer serializer_class = NoteSerializer
lookup_field = 'uuid' lookup_field = 'uuid'
RouterTestModel.objects.create(uuid='123', text='foo bar')
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes', NoteViewSet)
...@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase): ...@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
) )
RouterTestModel.objects.create(uuid='123', text='foo bar')
def test_custom_lookup_field_route(self): def test_custom_lookup_field_route(self):
detail_route = self.router.urls[-1] detail_route = self.router.urls[-1]
detail_url_pattern = detail_route.regex.pattern detail_url_pattern = detail_route.regex.pattern
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment