Commit 65bca59e by Tom Christie

Reload api_settings when using Django's 'override_settings'

parent fc70c086
...@@ -18,6 +18,7 @@ REST framework settings, checking for user settings first, then falling ...@@ -18,6 +18,7 @@ REST framework settings, checking for user settings first, then falling
back to the defaults. back to the defaults.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test.signals import setting_changed
from django.conf import settings from django.conf import settings
from django.utils import importlib, six from django.utils import importlib, six
from rest_framework import ISO_8601 from rest_framework import ISO_8601
...@@ -198,3 +199,13 @@ class APISettings(object): ...@@ -198,3 +199,13 @@ class APISettings(object):
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs):
global api_settings
setting, value = kwargs['setting'], kwargs['value']
if setting == 'REST_FRAMEWORK':
api_settings = APISettings(value, DEFAULTS, IMPORT_STRINGS)
setting_changed.connect(reload_api_settings)
...@@ -5,13 +5,15 @@ from django.db import models ...@@ -5,13 +5,15 @@ from django.db import models
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings
from django.utils import unittest from django.utils import unittest
from django.utils.dateparse import parse_date from django.utils.dateparse import parse_date
from django.utils.six.moves import reload_module
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from .models import BaseFilterableItem, FilterableItem, BasicModel from .models import BaseFilterableItem, FilterableItem, BasicModel
from .utils import temporary_setting
factory = APIRequestFactory() factory = APIRequestFactory()
...@@ -404,7 +406,9 @@ class SearchFilterTests(TestCase): ...@@ -404,7 +406,9 @@ class SearchFilterTests(TestCase):
) )
def test_search_with_nonstandard_search_param(self): def test_search_with_nonstandard_search_param(self):
with temporary_setting('SEARCH_PARAM', 'query', module=filters): with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
...@@ -422,6 +426,8 @@ class SearchFilterTests(TestCase): ...@@ -422,6 +426,8 @@ class SearchFilterTests(TestCase):
] ]
) )
reload_module(filters)
class OrderingFilterModel(models.Model): class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20) title = models.CharField(max_length=20)
...@@ -642,7 +648,9 @@ class OrderingFilterTests(TestCase): ...@@ -642,7 +648,9 @@ class OrderingFilterTests(TestCase):
) )
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with temporary_setting('ORDERING_PARAM', 'order', filters): with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
...@@ -662,6 +670,8 @@ class OrderingFilterTests(TestCase): ...@@ -662,6 +670,8 @@ class OrderingFilterTests(TestCase):
] ]
) )
reload_module(filters)
class SensitiveOrderingFilterModel(models.Model): class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20) username = models.CharField(max_length=20)
......
from contextlib import contextmanager
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import NoReverseMatch from django.core.urlresolvers import NoReverseMatch
from django.utils import six
from rest_framework.settings import api_settings
@contextmanager
def temporary_setting(setting, value, module=None):
"""
Temporarily change value of setting for test.
Optionally reload given module, useful when module uses value of setting on
import.
"""
original_value = getattr(api_settings, setting)
setattr(api_settings, setting, value)
if module is not None:
six.moves.reload_module(module)
yield
setattr(api_settings, setting, original_value)
if module is not None:
six.moves.reload_module(module)
class MockObject(object): class MockObject(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment