utils.py 8.67 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
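"""
Utility classes and functions for testing django applications.

:py:class:`CacheIsolationMixin`
    A mixin helping to write tests which are isolated from cached data.

:py:class:`CacheIsolationTestCase`
    A TestCase base class that has per-test isolated caches.

The module also provides a query-count mixin that can ignore blacklisted
tables, a nose plugin that closes database connections before tests run,
and helpers for building mock requests and skipping CMS/LMS-only tests.
"""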

import copy
12
import re
13
from unittest import skipUnless
14

15
import crum
16
from django import db
17 18
from django.conf import settings
from django.contrib import sites
19
from django.contrib.auth.models import AnonymousUser
20
from django.core.cache import caches
21
from django.db import DEFAULT_DB_ALIAS, connections
22
from django.test import RequestFactory, TestCase, override_settings
23
from django.test.utils import CaptureQueriesContext
24
from nose.plugins import Plugin
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
from request_cache.middleware import RequestCache


class CacheIsolationMixin(object):
    """
    This class can be used to enable specific django caches for
    the specific TestCase that it's mixed into.

    Usage:

    Use the ENABLED_CACHES to list the names of caches that should
    be enabled in the context of this TestCase. These caches will
    use a LocMemCache with the default settings; the 'default' cache
    becomes a no-op DummyCache unless it is itself listed.

    Set the class variable CACHES to explicitly specify the cache settings
    that should be overridden. This class will insert those values into
    django.conf.settings, and will reset all named caches before each
    test.

    If both CACHES and ENABLED_CACHES are not None, raises an error.
    """

    CACHES = None
    ENABLED_CACHES = None

    # Stacks shared by every class using this mixin: the name-mangled
    # attributes below always resolve to these two lists. Each call to
    # start_cache_isolation() pushes exactly one entry onto each stack (a
    # ``None`` placeholder when the class configures no cache overrides), so
    # the matching end_cache_isolation() pops its own entry rather than one
    # pushed by another class.
    __settings_overrides = []
    __old_settings = []

    @classmethod
    def start_cache_isolation(cls):
        """
        Start cache isolation by overriding the settings.CACHES and
        flushing the cache.

        Raises:
            Exception: if both CACHES and ENABLED_CACHES are set.
        """
        if cls.CACHES is not None and cls.ENABLED_CACHES is not None:
            raise Exception(
                "Use either CACHES or ENABLED_CACHES, but not both"
            )

        cache_settings = None
        if cls.CACHES is not None:
            cache_settings = cls.CACHES
        elif cls.ENABLED_CACHES is not None:
            # 'default' is a no-op cache unless ENABLED_CACHES lists it,
            # in which case the update below replaces it with a LocMemCache.
            cache_settings = {
                'default': {
                    'BACKEND': 'django.core.cache.backends.dummy.DummyCache',
                }
            }

            cache_settings.update({
                cache_name: {
                    'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
                    'LOCATION': cache_name,
                    'KEY_FUNCTION': 'util.memcache.safe_key',
                } for cache_name in cls.ENABLED_CACHES
            })

        if cache_settings is None:
            # Nothing to override, but keep both stacks balanced so that the
            # paired end_cache_isolation() pops this placeholder instead of
            # exiting an override that belongs to another class.
            cls.__old_settings.append(None)
            cls.__settings_overrides.append(None)
            return

        cls.__old_settings.append(copy.deepcopy(settings.CACHES))
        override = override_settings(CACHES=cache_settings)
        override.__enter__()
        cls.__settings_overrides.append(override)

        assert settings.CACHES == cache_settings

        # Start with empty caches
        cls.clear_caches()

    @classmethod
    def end_cache_isolation(cls):
        """
        End cache isolation by flushing the cache and then returning
        settings.CACHES to its original state.

        Pops the entry pushed by the matching start_cache_isolation() call.
        If that entry is a placeholder, or if no isolation is active, only
        the cache flush happens.
        """
        # Make sure that cache contents don't leak out after the isolation is ended
        cls.clear_caches()

        if not cls.__settings_overrides:
            return

        # Pop outside of any assert: assert statements are stripped under
        # ``python -O``, which would otherwise leave the stacks unbalanced.
        override = cls.__settings_overrides.pop()
        old_settings = cls.__old_settings.pop()
        if override is not None:
            override.__exit__(None, None, None)
            assert settings.CACHES == old_settings

    @classmethod
    def clear_caches(cls):
        """
        Clear all of the caches defined in settings.CACHES.
        """
        # N.B. As of 2016-04-20, Django won't return any caches
        # from django.core.cache.caches.all() that haven't been
        # accessed using caches[name] previously, so we loop
        # over our list of overridden caches, instead.
        for cache in settings.CACHES:
            caches[cache].clear()

        # The sites framework caches in a module-level dictionary.
        # Clear that.
        sites.models.SITE_CACHE.clear()

        RequestCache.clear_request_cache()


class CacheIsolationTestCase(CacheIsolationMixin, TestCase):
    """
    Django TestCase combined with :py:class:`CacheIsolationMixin`.

    Cache settings are overridden once for the whole class, and every
    configured cache is emptied both before and after each individual test.
    """

    @classmethod
    def setUpClass(cls):
        """Run the parent class setup first, then begin cache isolation."""
        parent = super(CacheIsolationTestCase, cls)
        parent.setUpClass()
        cls.start_cache_isolation()

    @classmethod
    def tearDownClass(cls):
        """End cache isolation first, then run the parent class teardown."""
        cls.end_cache_isolation()
        parent = super(CacheIsolationTestCase, cls)
        parent.tearDownClass()

    def setUp(self):
        """Empty the caches now, and again once this test has finished."""
        super(CacheIsolationTestCase, self).setUp()

        flush = self.clear_caches
        flush()
        self.addCleanup(flush)


class _AssertNumQueriesContext(CaptureQueriesContext):
    """
    This is a copy of Django's internal class of the same name, with the
    addition of being able to provide a table_blacklist used to filter queries
    before comparing the count.
    """
    def __init__(self, test_case, num, connection, table_blacklist=None):
        """
        Same as Django's _AssertNumQueriesContext __init__, with the addition of
        the following argument:
            table_blacklist (List): A list of table names to filter out of the
                set of queries that get counted.
        """
        self.test_case = test_case
        self.num = num
        self.table_blacklist = table_blacklist
        super(_AssertNumQueriesContext, self).__init__(connection)

    def _blacklist_patterns(self):
        """
        Return one compiled regex per blacklisted table (empty if none).

        SQL contains the following format for columns:
        "table_name"."column_name".  Each regex requires a character other
        than "." before the quoted name to avoid matching columns.  Table
        names are passed through re.escape so they are matched literally,
        even if they contain regex metacharacters.

        Note: This is a simple naive implementation that makes no attempt
        to parse the query.
        """
        return [
            re.compile(r'[^.]"{}"'.format(re.escape(table)))
            for table in (self.table_blacklist or [])
        ]

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stop capturing queries. If the wrapped block raised nothing, assert
        that the number of captured queries which mention no blacklisted
        table equals ``self.num``.
        """
        super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback)
        if exc_type is not None:
            return

        # Compile once per exit instead of once per (query, table) pair.
        patterns = self._blacklist_patterns()
        filtered_queries = [
            query for query in self.captured_queries
            if not any(pattern.search(query['sql']) for pattern in patterns)
        ]
        executed = len(filtered_queries)
        self.test_case.assertEqual(
            executed, self.num,
            "%d queries executed, %d expected\nCaptured queries were:\n%s" % (
                executed, self.num,
                '\n'.join(
                    query['sql'] for query in filtered_queries
                )
            )
        )


class FilteredQueryCountMixin(object):
    """
    Mixin for subclasses of Django's TestCase that swaps in an
    assertNumQueries able to ignore queries against blacklisted tables.
    """
    def assertNumQueries(self, num, func=None, table_blacklist=None, *args, **kwargs):
        """
        Drop-in replacement for Django's assertNumQueries with one extra
        argument:
            table_blacklist (List): table names whose queries are left out
                of the count.

        Without ``func``, the counting context manager is returned.
        Otherwise ``func(*args, **kwargs)`` is run inside it and None is
        returned.  The database alias may be chosen with the ``using``
        keyword; it is removed from kwargs before ``func`` is called.

        Note: table_blacklist sits between func and *args, so any extra
        positional arguments meant for func must be preceded by an explicit
        table_blacklist (e.g. None).
        """
        database = connections[kwargs.pop("using", DEFAULT_DB_ALIAS)]
        counter = _AssertNumQueriesContext(
            self, num, database, table_blacklist=table_blacklist
        )

        if func is not None:
            with counter:
                func(*args, **kwargs)
            return None

        return counter


class NoseDatabaseIsolation(Plugin):
    """
    Nose plugin that closes every django database connection before the
    test run begins.

    This ensures that tests running in multiple processes don't share a
    database connection.
    """
    name = "database-isolation"

    def begin(self):
        """
        Close all django database connections before any test starts.
        """
        all_connections = db.connections.all()
        for connection in all_connections:
            connection.close()


def get_mock_request(user=None):
    """
    Build a GET request for "/" and register it as crum's current request.

    Arguments:
        user: user to attach to the request; an AnonymousUser is attached
            when this is None.

    Returns:
        The request object, patched so that is_secure() returns True and
        get_host() returns "edx.org".
    """
    mock_request = RequestFactory().get('/')
    mock_request.user = AnonymousUser() if user is None else user
    mock_request.is_secure = lambda: True
    mock_request.get_host = lambda: "edx.org"
    crum.set_current_request(mock_request)
    return mock_request


def _skip_unless_root_urlconf(func, urlconf, reason):
    """
    Wrap *func* with unittest.skipUnless so that it only runs when
    settings.ROOT_URLCONF equals *urlconf*; *reason* is the skip message.
    The setting is read when the decorator is applied.
    """
    in_matching_suite = settings.ROOT_URLCONF == urlconf
    return skipUnless(in_matching_suite, reason)(func)


def skip_unless_cms(func):
    """
    Only run the decorated test in the CMS test suite
    (i.e. when ROOT_URLCONF is 'cms.urls').
    """
    return _skip_unless_root_urlconf(func, 'cms.urls', 'Test only valid in CMS')


def skip_unless_lms(func):
    """
    Only run the decorated test in the LMS test suite
    (i.e. when ROOT_URLCONF is 'lms.urls').
    """
    return _skip_unless_root_urlconf(func, 'lms.urls', 'Test only valid in LMS')