Commit 24a68822 by Tom Christie

Merge pull request #1375 from linovia/feature/django_1_7

Django 1.7 compatibility
parents ee9864e0 3d7cb72e
...@@ -7,6 +7,7 @@ python: ...@@ -7,6 +7,7 @@ python:
- "3.3" - "3.3"
env: env:
- DJANGO="https://www.djangoproject.com/download/1.7a2/tarball/"
- DJANGO="django==1.6.2" - DJANGO="django==1.6.2"
- DJANGO="django==1.5.5" - DJANGO="django==1.5.5"
- DJANGO="django==1.4.10" - DJANGO="django==1.4.10"
...@@ -14,13 +15,15 @@ env: ...@@ -14,13 +15,15 @@ env:
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install defusedxml==0.3 - pip install defusedxml==0.3 Pillow
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.1; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-guardian==1.1.1; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-guardian==1.1.1; fi"
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi" - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6; fi" - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7a2/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
...@@ -28,6 +31,8 @@ script: ...@@ -28,6 +31,8 @@ script:
matrix: matrix:
exclude: exclude:
- python: "2.6"
env: DJANGO="https://www.djangoproject.com/download/1.7a2/tarball/"
- python: "3.2" - python: "3.2"
env: DJANGO="django==1.4.10" env: DJANGO="django==1.4.10"
- python: "3.2" - python: "3.2"
......
...@@ -584,3 +584,23 @@ if six.PY3: ...@@ -584,3 +584,23 @@ if six.PY3:
else: else:
def is_non_str_iterable(obj): def is_non_str_iterable(obj):
return hasattr(obj, '__iter__') return hasattr(obj, '__iter__')
try:
from django.utils.encoding import python_2_unicode_compatible
except ImportError:
def python_2_unicode_compatible(klass):
"""
A decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class.
"""
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
return klass
...@@ -26,6 +26,10 @@ def usage(): ...@@ -26,6 +26,10 @@ def usage():
def main(): def main():
try:
django.setup()
except AttributeError:
pass
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
test_runner = TestRunner() test_runner = TestRunner()
......
...@@ -8,6 +8,7 @@ from django.conf import settings ...@@ -8,6 +8,7 @@ from django.conf import settings
from django.test.client import Client as DjangoClient from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler from django.test.client import ClientHandler
from django.test import testcases from django.test import testcases
from django.utils.http import urlencode
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.compat import force_bytes_or_smart_bytes, six
...@@ -71,6 +72,13 @@ class APIRequestFactory(DjangoRequestFactory): ...@@ -71,6 +72,13 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type return ret, content_type
def get(self, path, data=None, **extra):
r = {
'QUERY_STRING': urlencode(data or {}, doseq=True),
}
r.update(extra)
return self.generic('GET', path, **r)
def post(self, path, data=None, format=None, content_type=None, **extra): def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra) return self.generic('POST', path, data, content_type, **extra)
......
...@@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel): ...@@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel):
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
# Models to test filters
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
...@@ -9,16 +9,11 @@ from rest_framework import generics, serializers, status, filters ...@@ -9,16 +9,11 @@ from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters: if django_filters:
# Basic filter on a list view. # Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
...@@ -128,7 +123,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -128,7 +123,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
...@@ -136,7 +131,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -136,7 +131,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter works. # Tests that the date filter works.
search_date = datetime.date(2012, 9, 22) search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if f['date'] == search_date]
...@@ -151,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -151,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
...@@ -184,7 +179,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -184,7 +179,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set with 'lt' in the filter class works. # Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal] expected_data = [f for f in self.data if f['decimal'] < search_decimal]
...@@ -192,7 +187,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -192,7 +187,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if f['date'] > search_date]
...@@ -200,7 +195,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -200,7 +195,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff' search_text = 'ff'
request = factory.get('/?text=%s' % search_text) request = factory.get('/', {'text': '%s' % search_text})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()] expected_data = [f for f in self.data if search_text in f['text'].lower()]
...@@ -209,7 +204,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -209,7 +204,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that multiple filters works. # Tests that multiple filters works.
search_decimal = Decimal('5.25') search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) request = factory.get('/', {
'decimal': '%s' % (search_decimal,),
'date': '%s' % (search_date,)
})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and
...@@ -234,7 +232,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): ...@@ -234,7 +232,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
view = FilterFieldsRootView.as_view() view = FilterFieldsRootView.as_view()
search_integer = 10 search_integer = 10
request = factory.get('/?integer=%s' % search_integer) request = factory.get('/', {'integer': '%s' % search_integer})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -265,14 +263,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): ...@@ -265,14 +263,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set that should fail. # Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0] high_item = self.objects.filter(decimal__gt=search_decimal)[0]
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(high_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed. # Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0] low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item) low_item_data = self._serialize_object(low_item)
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(low_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data) self.assertEqual(response.data, low_item_data)
...@@ -281,7 +283,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): ...@@ -281,7 +283,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item) valid_item_data = self._serialize_object(valid_item)
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) response = self.client.get(
'{url}'.format(url=self._get_url(valid_item)), {
'decimal': '{decimal}'.format(decimal=search_decimal),
'date': '{date}'.format(date=search_date)
})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data) self.assertEqual(response.data, valid_item_data)
...@@ -315,7 +321,7 @@ class SearchFilterTests(TestCase): ...@@ -315,7 +321,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -332,7 +338,7 @@ class SearchFilterTests(TestCase): ...@@ -332,7 +338,7 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text') search_fields = ('=title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=zzz') request = factory.get('/', {'search': 'zzz'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -348,7 +354,7 @@ class SearchFilterTests(TestCase): ...@@ -348,7 +354,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text') search_fields = ('title', '^text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -396,7 +402,7 @@ class OrderingFilterTests(TestCase): ...@@ -396,7 +402,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=text') request = factory.get('/', {'ordering': 'text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -415,7 +421,7 @@ class OrderingFilterTests(TestCase): ...@@ -415,7 +421,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-text') request = factory.get('/', {'ordering': '-text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -434,7 +440,7 @@ class OrderingFilterTests(TestCase): ...@@ -434,7 +440,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=foobar') request = factory.get('/', {'ordering': 'foobar'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -503,7 +509,7 @@ class OrderingFilterTests(TestCase): ...@@ -503,7 +509,7 @@ class OrderingFilterTests(TestCase):
models.Count("relateds")) models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=relateds__count') request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
...@@ -566,7 +572,7 @@ class SensitiveOrderingFilterTests(TestCase): ...@@ -566,7 +572,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-username') request = factory.get('/', {'ordering': '-username'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
...@@ -596,7 +602,7 @@ class SensitiveOrderingFilterTests(TestCase): ...@@ -596,7 +602,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=password') request = factory.get('/', {'ordering': 'password'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
......
...@@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK ...@@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import python_2_unicode_compatible
@python_2_unicode_compatible
class Tag(models.Model): class Tag(models.Model):
""" """
Tags have a descriptive slug, and are attached to an arbitrary object. Tags have a descriptive slug, and are attached to an arbitrary object.
...@@ -15,10 +17,11 @@ class Tag(models.Model): ...@@ -15,10 +17,11 @@ class Tag(models.Model):
object_id = models.PositiveIntegerField() object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') tagged_item = GenericForeignKey('content_type', 'object_id')
def __unicode__(self): def __str__(self):
return self.tag return self.tag
@python_2_unicode_compatible
class Bookmark(models.Model): class Bookmark(models.Model):
""" """
A URL bookmark that may have multiple tags attached. A URL bookmark that may have multiple tags attached.
...@@ -26,10 +29,11 @@ class Bookmark(models.Model): ...@@ -26,10 +29,11 @@ class Bookmark(models.Model):
url = models.URLField() url = models.URLField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __unicode__(self): def __str__(self):
return 'Bookmark: %s' % self.url return 'Bookmark: %s' % self.url
@python_2_unicode_compatible
class Note(models.Model): class Note(models.Model):
""" """
A textual note that may have multiple tags attached. A textual note that may have multiple tags attached.
...@@ -37,7 +41,7 @@ class Note(models.Model): ...@@ -37,7 +41,7 @@ class Note(models.Model):
text = models.TextField() text = models.TextField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __unicode__(self): def __str__(self):
return 'Note: %s' % self.text return 'Note: %s' % self.text
......
...@@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase): ...@@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == 'example.html':
return Template("example: {{ object }}") return Template("example: {{ object }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
...@@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase): ...@@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404: Not found")) self.assertTrue(response.content in (
six.b("404: Not found"), six.b("404 Not Found")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403: Permission denied")) self.assertTrue(response.content in (
six.b("403: Permission denied"), six.b("403 Forbidden")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
...@@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers ...@@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
# Helper function to split arguments out of an url
def split_arguments_from_url(url):
if '?' not in url:
return url
class FilterableItem(models.Model): path, args = url.split('?')
text = models.CharField(max_length=100) args = dict(r.split('=') for r in args.split('&'))
decimal = models.DecimalField(max_digits=4, decimal_places=2) return path, args
date = models.DateField()
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
...@@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase): ...@@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase): ...@@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
EXPECTED_NUM_QUERIES = 2 EXPECTED_NUM_QUERIES = 2
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = BasicFilterFieldsRootView.as_view() view = BasicFilterFieldsRootView.as_view()
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ...@@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase): ...@@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):
""" """
If paginate_by_param is set, the new kwarg should limit per view requests. If paginate_by_param is set, the new kwarg should limit per view requests.
""" """
request = factory.get('/?page_size=5') request = factory.get('/', {'page_size': 5})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])
...@@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase): ...@@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):
""" """
If max_paginate_by is set, it should limit page size for the view. If max_paginate_by is set, it should limit page size for the view.
""" """
request = factory.get('/?page_size=10') request = factory.get('/', data={'page_size': 10})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])
......
...@@ -3,9 +3,7 @@ from django.db import models ...@@ -3,9 +3,7 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from .models import OneToOneTarget
class OneToOneTarget(models.Model):
name = models.CharField(max_length=100)
class OneToOneSource(models.Model): class OneToOneSource(models.Model):
......
...@@ -613,6 +613,10 @@ class CacheRenderTest(TestCase): ...@@ -613,6 +613,10 @@ class CacheRenderTest(TestCase):
method = getattr(self.client, http_method) method = getattr(self.client, http_method)
resp = method(url) resp = method(url)
del resp.client, resp.request del resp.client, resp.request
try:
del resp.wsgi_request
except AttributeError:
pass
return resp return resp
def test_obj_pickling(self): def test_obj_pickling(self):
......
...@@ -14,6 +14,26 @@ import datetime ...@@ -14,6 +14,26 @@ import datetime
import pickle import pickle
class AMOAFModel(RESTFrameworkModel):
char_field = models.CharField(max_length=1024, blank=True)
comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
email_field = models.EmailField(max_length=1024, blank=True)
file_field = models.FileField(upload_to='test', max_length=1024, blank=True)
image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True)
positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
email_field = models.EmailField(blank=True)
file_field = models.FileField(upload_to='test', blank=True)
image_field = models.ImageField(upload_to='test', blank=True)
slug_field = models.SlugField(blank=True)
url_field = models.URLField(blank=True)
class SubComment(object): class SubComment(object):
def __init__(self, sub_comment): def __init__(self, sub_comment):
self.sub_comment = sub_comment self.sub_comment = sub_comment
...@@ -1496,15 +1516,6 @@ class ManyFieldHelpTextTest(TestCase): ...@@ -1496,15 +1516,6 @@ class ManyFieldHelpTextTest(TestCase):
class AttributeMappingOnAutogeneratedFieldsTests(TestCase): class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def setUp(self): def setUp(self):
class AMOAFModel(RESTFrameworkModel):
char_field = models.CharField(max_length=1024, blank=True)
comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
email_field = models.EmailField(max_length=1024, blank=True)
file_field = models.FileField(max_length=1024, blank=True)
image_field = models.ImageField(max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
class AMOAFSerializer(serializers.ModelSerializer): class AMOAFSerializer(serializers.ModelSerializer):
class Meta: class Meta:
...@@ -1577,14 +1588,6 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase): ...@@ -1577,14 +1588,6 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
class DefaultValuesOnAutogeneratedFieldsTests(TestCase): class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
def setUp(self): def setUp(self):
class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True)
positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
email_field = models.EmailField(blank=True)
file_field = models.FileField(blank=True)
image_field = models.ImageField(blank=True)
slug_field = models.SlugField(blank=True)
url_field = models.URLField(blank=True)
class DVOAFSerializer(serializers.ModelSerializer): class DVOAFSerializer(serializers.ModelSerializer):
class Meta: class Meta:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment