Commit 5b7e4af0 by Tom Christie

get_base_field() refactor

parent c0155fd9
......@@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
def field_name_to_label(field_name):
return field_name.replace('_', ' ').capitalize()
class SkipField(Exception):
pass
......@@ -162,7 +158,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name.
if self.label is None:
self.label = field_name_to_label(self.field_name)
self.label = field_name.replace('_', ' ').capitalize()
# self.source should default to being the same as the field name.
if self.source is None:
......
......@@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
}
def __init__(self, view_name, **kwargs):
def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
......@@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
URL of relationships to other objects.
"""
def __init__(self, view_name, **kwargs):
def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
......@@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
'invalid': _('Invalid value.'),
}
def __init__(self, slug_field, **kwargs):
def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
......
"""
Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivelent serializer fields.
"""
from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
import inspect
def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
for cls in inspect.getmro(instance.__class__):
if cls in mapping:
return mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
"""
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
return capfirst(model_field.verbose_name) != default_label
def get_detail_view_name(model):
"""
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
def get_field_kwargs(field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.flatchoices:
# If this model field contains choices, then return now,
# any further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
# Also note that Django<1.6 uses `default=False` for
# models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
return kwargs
def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, has_through_model = relation_info
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
if relation_info.to_many:
kwargs['many'] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}
"""
Helper functions for returning the field information that is associated
Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse
relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
from django.db import models
......@@ -9,8 +11,22 @@ from django.utils import six
from django.utils.datastructures import SortedDict
import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related',
'to_many',
'has_through_model'
])
def _resolve_model(obj):
......@@ -55,7 +71,7 @@ def get_field_info(model):
forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
field=field,
model_field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
......@@ -64,7 +80,7 @@ def get_field_info(model):
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
field=field,
model_field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
......@@ -77,7 +93,7 @@ def get_field_info(model):
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
model_field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
......@@ -87,7 +103,7 @@ def get_field_info(model):
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
model_field=None,
related=relation.model,
to_many=True,
has_through_model=(
......@@ -96,4 +112,18 @@ def get_field_info(model):
)
)
return FieldInfo(pk, fields, forward_relations, reverse_relations)
# Shortcut that merges both regular fields and the pk,
# for simplifying regular field lookup.
fields_and_pk = SortedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
# Shortcut that merges both forward and reverse relationships
relations = SortedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
from __future__ import unicode_literals
from django.db import models
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
def foobar():
......@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source')
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
......@@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = NestedModelSerializer(read_only=True):
foreign_key = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True):
through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
""")
......@@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent("""
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedModelSerializer(read_only=True):
foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
name = CharField(max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100)
through = NestedModelSerializer(many=True, read_only=True):
through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
......
......@@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.conf.urls import patterns, url, include
from django.test import TestCase
from django.utils import six
from tests.models import BasicModel, BasicModelSerializer
from tests.models import BasicModel
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import generics
from rest_framework import routers
from rest_framework import serializers
from rest_framework import status
from rest_framework.renderers import (
BaseRenderer,
......@@ -17,6 +18,12 @@ from rest_framework import viewsets
from rest_framework.settings import api_settings
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
class MockPickleRenderer(BaseRenderer):
media_type = 'application/pickle'
......
......@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
class Meta:
model = RouterTestModel
lookup_field = 'uuid'
fields = ('url', 'uuid', 'text')
class NoteViewSet(viewsets.ModelViewSet):
......@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
serializer_class = NoteSerializer
lookup_field = 'uuid'
RouterTestModel.objects.create(uuid='123', text='foo bar')
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
......@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
url(r'^', include(self.router.urls)),
)
RouterTestModel.objects.create(uuid='123', text='foo bar')
def test_custom_lookup_field_route(self):
detail_route = self.router.urls[-1]
detail_url_pattern = detail_route.regex.pattern
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment