Commit 040bfcc0 by Tom Christie

NotImplemented stubs for Field, and DecimalField improvements

parent a7518719
...@@ -229,13 +229,13 @@ class Field(object): ...@@ -229,13 +229,13 @@ class Field(object):
""" """
Transform the *incoming* primative data into a native value. Transform the *incoming* primative data into a native value.
""" """
return data raise NotImplementedError('to_native() must be implemented.')
def to_primative(self, value): def to_primative(self, value):
""" """
Transform the *outgoing* native value into primative data. Transform the *outgoing* native value into primative data.
""" """
return value raise NotImplementedError('to_primative() must be implemented.')
def fail(self, key, **kwargs): def fail(self, key, **kwargs):
""" """
...@@ -429,9 +429,10 @@ class DecimalField(Field): ...@@ -429,9 +429,10 @@ class DecimalField(Field):
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
} }
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs):
self.max_value, self.min_value = max_value, min_value self.max_digits = max_digits
self.max_digits, self.max_decimal_places = max_digits, decimal_places self.decimal_places = decimal_places
self.coerce_to_string = coerce_to_string
super(DecimalField, self).__init__(**kwargs) super(DecimalField, self).__init__(**kwargs)
if max_value is not None: if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value)) self.validators.append(validators.MaxValueValidator(max_value))
...@@ -478,12 +479,26 @@ class DecimalField(Field): ...@@ -478,12 +479,26 @@ class DecimalField(Field):
if self.max_digits is not None and digits > self.max_digits: if self.max_digits is not None and digits > self.max_digits:
self.fail('max_digits', max_digits=self.max_digits) self.fail('max_digits', max_digits=self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places: if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
return value return value
def to_primative(self, value):
if not self.coerce_to_string:
return value
if isinstance(value, decimal.Decimal):
context = decimal.getcontext().copy()
context.prec = self.max_digits
quantized = value.quantize(
decimal.Decimal('.1') ** self.decimal_places,
context=context
)
return '{0:f}'.format(quantized)
return '%.*f' % (self.max_decimal_places, value)
# Date & time fields... # Date & time fields...
......
...@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field): ...@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page) return replace_query_param(url, self.page_field, page)
class DefaultObjectSerializer(serializers.Field): class DefaultObjectSerializer(serializers.ReadOnlyField):
""" """
If no object serializer is specified, then this serializer will be applied If no object serializer is specified, then this serializer will be applied
as the default. as the default.
...@@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer): ...@@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer):
""" """
A default implementation of a pagination serializer. A default implementation of a pagination serializer.
""" """
count = serializers.Field(source='paginator.count') count = serializers.ReadOnlyField(source='paginator.count')
next = NextPageField(source='*') next = NextPageField(source='*')
previous = PreviousPageField(source='*') previous = PreviousPageField(source='*')
...@@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder): ...@@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder):
elif isinstance(o, datetime.timedelta): elif isinstance(o, datetime.timedelta):
return str(o.total_seconds()) return str(o.total_seconds())
elif isinstance(o, decimal.Decimal): elif isinstance(o, decimal.Decimal):
return str(o) return float(o)
elif isinstance(o, QuerySet): elif isinstance(o, QuerySet):
return list(o) return list(o)
elif hasattr(o, 'tolist'): elif hasattr(o, 'tolist'):
......
...@@ -102,7 +102,7 @@ if django_filters: ...@@ -102,7 +102,7 @@ if django_filters:
class CommonFilteringTestCase(TestCase): class CommonFilteringTestCase(TestCase):
def _serialize_object(self, obj): def _serialize_object(self, obj):
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date}
def setUp(self): def setUp(self):
""" """
...@@ -145,7 +145,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -145,7 +145,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
# Tests that the date filter works. # Tests that the date filter works.
...@@ -168,7 +168,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -168,7 +168,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filter not installed') @unittest.skipUnless(django_filters, 'django-filter not installed')
...@@ -201,7 +201,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -201,7 +201,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/', {'decimal': '%s' % search_decimal}) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal] expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
...@@ -230,7 +230,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -230,7 +230,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and
f['decimal'] < search_decimal] Decimal(f['decimal']) < search_decimal]
self.assertEqual(response.data, expected_data) self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filter not installed') @unittest.skipUnless(django_filters, 'django-filter not installed')
......
...@@ -135,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -135,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date}
for obj in self.objects.all() for obj in self.objects.all()
] ]
...@@ -381,7 +381,7 @@ class TestMaxPaginateByParam(TestCase): ...@@ -381,7 +381,7 @@ class TestMaxPaginateByParam(TestCase):
# Tests for context in pagination serializers # Tests for context in pagination serializers
class CustomField(serializers.Field): class CustomField(serializers.ReadOnlyField):
def to_native(self, value): def to_native(self, value):
if 'view' not in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into custom field") raise RuntimeError("context isn't getting passed into custom field")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment