"""Tests for util.db module."""

import ddt
import threading
import time
import unittest

from django.contrib.auth.models import User
from django.db import connection, IntegrityError
from django.db.transaction import atomic, TransactionManagementError
from django.test import TestCase, TransactionTestCase

from util.db import commit_on_success, generate_int_id, outer_atomic


@ddt.ddt
class TransactionManagersTestCase(TransactionTestCase):
    """
    Tests commit_on_success and outer_atomic.

    Note: This TestCase only works with MySQL.

    To test do: "./manage.py lms --settings=test_with_mysql test util.tests.test_db"
    """

    @ddt.data(
        (outer_atomic(), IntegrityError, None, True),
        (outer_atomic(read_committed=True), type(None), False, True),
        (commit_on_success(), IntegrityError, None, True),
        (commit_on_success(read_committed=True), type(None), False, True),
    )
    @ddt.unpack
    def test_concurrent_requests(self, transaction_decorator, exception_class, created_in_1, created_in_2):
        """
        Test get_or_create() for the same row from two concurrent requests.

        Both threads first check that the row does not exist. Thread 1 then
        sleeps for a second, so its get_or_create() normally runs after
        thread 2 has already created and committed the row.

        Without read_committed, thread 1's get_or_create() is expected to
        raise IntegrityError, because its transaction does not see the row
        thread 2 committed. When the isolation level is set to READ COMMITTED,
        thread 1 sees that row instead and gets it without raising
        (created=False).

        Args:
            transaction_decorator: decorator applied to each thread's view.
            exception_class: expected type of the exception recorded by
                thread 1 (``type(None)`` when none is expected).
            created_in_1: expected ``created`` flag for thread 1 (None when
                it is expected to fail).
            created_in_2: expected ``created`` flag for thread 2.
        """

        if connection.vendor != 'mysql':
            raise unittest.SkipTest('Only works on MySQL.')

        class RequestThread(threading.Thread):
            """ A thread which runs a dummy view."""
            def __init__(self, delay, **kwargs):
                super(RequestThread, self).__init__(**kwargs)
                self.delay = delay  # seconds to sleep between the lookup and get_or_create()
                self.status = {}  # 'created' on success, 'exception' on failure

            def run(self):
                """Run the dummy view, then close this thread's DB connection."""
                try:
                    self.dummy_view()
                finally:
                    # Django keeps one connection per thread and only closes it
                    # automatically at the end of a request; close it explicitly
                    # so it does not linger after this thread exits.
                    connection.close()

            @transaction_decorator
            def dummy_view(self):
                """A dummy view, run inside the transaction under test."""
                try:
                    try:
                        User.objects.get(username='student', email='student@edx.org')
                    except User.DoesNotExist:
                        pass
                    else:
                        raise AssertionError('Did not raise User.DoesNotExist.')

                    if self.delay > 0:
                        time.sleep(self.delay)

                    __, created = User.objects.get_or_create(username='student', email='student@edx.org')
                except Exception as exception:  # pylint: disable=broad-except
                    self.status['exception'] = exception
                else:
                    self.status['created'] = created

        thread1 = RequestThread(delay=1)
        thread2 = RequestThread(delay=0)

        thread1.start()
        thread2.start()
        thread2.join()
        thread1.join()

        self.assertIsInstance(thread1.status.get('exception'), exception_class)
        self.assertEqual(thread1.status.get('created'), created_in_1)

        self.assertIsNone(thread2.status.get('exception'))
        self.assertEqual(thread2.status.get('created'), created_in_2)

    def test_outer_atomic_nesting(self):
        """
        Test that outer_atomic raises an error if it is nested inside
        another atomic (including another outer_atomic), while plain atomic
        blocks may still be nested inside it.
        """

        if connection.vendor != 'mysql':
            raise unittest.SkipTest('Only works on MySQL.')

        def do_nothing():
            """Just return."""
            return

        outer_atomic()(do_nothing)()

        with atomic():
            atomic()(do_nothing)()

        with outer_atomic():
            atomic()(do_nothing)()

        with self.assertRaisesRegexp(TransactionManagementError, 'Cannot be inside an atomic block.'):
            with atomic():
                outer_atomic()(do_nothing)()

        with self.assertRaisesRegexp(TransactionManagementError, 'Cannot be inside an atomic block.'):
            with outer_atomic():
                outer_atomic()(do_nothing)()

    def test_commit_on_success_nesting(self):
        """
        Test that commit_on_success raises an error if it is nested inside
        atomic or if the isolation level is changed when it is nested
        inside another commit_on_success.
        """
        # pylint: disable=not-callable

        if connection.vendor != 'mysql':
            raise unittest.SkipTest('Only works on MySQL.')

        def do_nothing():
            """Just return."""
            return

        commit_on_success(read_committed=True)(do_nothing)()

        with self.assertRaisesRegexp(TransactionManagementError, 'Cannot change isolation level when nested.'):
            with commit_on_success():
                commit_on_success(read_committed=True)(do_nothing)()

        with self.assertRaisesRegexp(TransactionManagementError, 'Cannot be inside an atomic block.'):
            with atomic():
                commit_on_success(read_committed=True)(do_nothing)()


@ddt.ddt
class GenerateIntIdTestCase(TestCase):
    """Tests for `generate_int_id`"""
    @ddt.data(10)
    def test_no_used_ids(self, times):
        """
        Verify that we get a random integer within the specified range
        when there are no used ids.

        Args:
            times: number of ids to draw; also the inclusive upper bound of
                the range [1, times].
        """
        minimum = 1
        maximum = times
        # The valid range does not change between draws, so build it once.
        allowed_ids = set(range(minimum, maximum + 1))
        for __ in range(times):
            self.assertIn(generate_int_id(minimum, maximum), allowed_ids)

    @ddt.data(10)
    def test_used_ids(self, times):
        """
        Verify that we get a random integer within the specified range
        but not in a list of used ids.

        Args:
            times: number of ids to draw; also the inclusive upper bound of
                the range [1, times].
        """
        minimum = 1
        maximum = times
        used_ids = {2, 4, 6, 8}
        # Every draw must come from the range minus the used ids; compute once.
        allowed_ids = set(range(minimum, maximum + 1)) - used_ids
        for __ in range(times):
            int_id = generate_int_id(minimum, maximum, used_ids)
            self.assertIn(int_id, allowed_ids)