Commit ab40780d by Tom Christie

Tidy up lookup_class

parent 3318f75a
...@@ -317,17 +317,17 @@ class ModelSerializerOptions(object): ...@@ -317,17 +317,17 @@ class ModelSerializerOptions(object):
self.depth = getattr(meta, 'depth', 0) self.depth = getattr(meta, 'depth', 0)
def lookup_class(mapping, obj): def lookup_class(mapping, instance):
""" """
Takes a dictionary with classes as keys, and an object. Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value resolution order, and returns the first matching value
from the dictionary or None. from the dictionary or raises a KeyError if nothing matches.
""" """
return next( for cls in inspect.getmro(instance.__class__):
(mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), if cls in mapping:
None return mapping[cls]
) raise KeyError('Class %s not found in lookup.', cls.__name__)
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
...@@ -341,6 +341,7 @@ class ModelSerializer(Serializer): ...@@ -341,6 +341,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField, models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField, models.DecimalField: DecimalField,
models.EmailField: EmailField, models.EmailField: EmailField,
models.Field: ModelField,
models.FileField: FileField, models.FileField: FileField,
models.FloatField: FloatField, models.FloatField: FloatField,
models.ImageField: ImageField, models.ImageField: ImageField,
...@@ -484,6 +485,7 @@ class ModelSerializer(Serializer): ...@@ -484,6 +485,7 @@ class ModelSerializer(Serializer):
""" """
Creates a default instance of a basic non-relational field. Creates a default instance of a basic non-relational field.
""" """
serializer_cls = lookup_class(self.field_mapping, model_field)
kwargs = {} kwargs = {}
validator_kwarg = model_field.validators validator_kwarg = model_field.validators
...@@ -602,11 +604,10 @@ class ModelSerializer(Serializer): ...@@ -602,11 +604,10 @@ class ModelSerializer(Serializer):
if validator_kwarg: if validator_kwarg:
kwargs['validators'] = validator_kwarg kwargs['validators'] = validator_kwarg
cls = lookup_class(self.field_mapping, model_field) if issubclass(serializer_cls, ModelField):
if cls is None:
cls = ModelField
kwargs['model_field'] = model_field kwargs['model_field'] = model_field
return cls(**kwargs)
return serializer_cls(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment