tests.py 9.14 KB
Newer Older
1
"""Tests for License package"""
2
import logging
3 4
import json

5 6 7
from uuid import uuid4
from random import shuffle
from tempfile import NamedTemporaryFile
8
from factory import DjangoModelFactory, SubFactory
9 10

from django.test import TestCase
11
from django.test.client import Client
12
from django.test.utils import override_settings
13
from django.core.management import call_command
14
from django.core.urlresolvers import reverse
15
from nose.tools import assert_true  # pylint: disable=E0611
16

17
from courseware.tests.modulestore_config import TEST_DATA_MIXED_MODULESTORE
18
from licenses.models import CourseSoftware, UserLicense
19 20

from student.tests.factories import UserFactory
21 22
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
23

24
# Course id that the licensed software in these tests is attached to.
COURSE_1 = 'edX/toy/2012_Fall'

# Short names of the software packages referenced by the tests.
SOFTWARE_1 = 'matlab'
SOFTWARE_2 = 'stata'

# Serial number that UserLicenseFactory assigns by default.
SERIAL_1 = '123456abcde'

log = logging.getLogger(__name__)


34
class CourseSoftwareFactory(DjangoModelFactory):
    """Factory for generating CourseSoftware objects in database"""
    FACTORY_FOR = CourseSoftware

    # All descriptive fields default to the SOFTWARE_1 short name, and the
    # software is attached to COURSE_1.
    name = SOFTWARE_1
    full_name = SOFTWARE_1
    url = SOFTWARE_1
    course_id = COURSE_1


44
class UserLicenseFactory(DjangoModelFactory):
    """
    Factory for generating UserLicense objects in database

    By default, the user assigned is null, indicating that the
    serial number has not yet been assigned.
    """
    FACTORY_FOR = UserLicense

    user = None
    # A CourseSoftware row is built through its own factory unless the caller
    # passes an explicit ``software``.
    software = SubFactory(CourseSoftwareFactory)
    serial = SERIAL_1


58
class LicenseTestCase(TestCase):
    """Tests for licenses.views"""

    def setUp(self):
        """Create a user, log them in and create the default course software."""
        super(LicenseTestCase, self).setUp()
        self.user = UserFactory(username='test',
                                email='test@edx.org', password='test_password')
        self.client = Client()
        self.assertTrue(self.client.login(username='test', password='test_password'))
        self.software = CourseSoftwareFactory()

    def _post_license_request(self, data, course_id=COURSE_1, ajax=True):
        """
        POST ``data`` to the ``user_software_license`` view and return the response.

        The referer header points at a page of ``course_id``; when ``ajax`` is
        true the request also carries the ``X-Requested-With: XMLHttpRequest``
        header.
        """
        headers = {'HTTP_REFERER': '/courses/{0}/some_page'.format(course_id)}
        if ajax:
            headers['HTTP_X_REQUESTED_WITH'] = 'XMLHttpRequest'
        return self.client.post(reverse('user_software_license'), data, **headers)

    def _assert_serial_returned(self, response):
        """Assert a 200 JSON response carrying SERIAL_1 and no error."""
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('error', json_returned)
        self.assertIn('serial', json_returned)
        self.assertEqual(json_returned['serial'], SERIAL_1)

    def _assert_error_returned(self, response):
        """Assert a 200 JSON response carrying an error and no serial."""
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('serial', json_returned)
        self.assertIn('error', json_returned)

    def test_get_license(self):
        """Should return the serial of a license already assigned to the user"""
        UserLicenseFactory(user=self.user, software=self.software)
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_serial_returned(response)

    def test_get_nonexistent_license(self):
        """Should return an error when the user has no license and none is generated"""
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_error_returned(response)

    def test_create_nonexistent_license(self):
        """Should not assign a license to an unlicensed user when none are available"""
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_error_returned(response)

    def test_create_license(self):
        """Should assign a license to an unlicensed user if one is unassigned"""
        # create an unassigned license
        UserLicenseFactory(software=self.software)
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_serial_returned(response)

    def test_get_license_from_wrong_course(self):
        """Should respond 404 when the referer belongs to another course"""
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'false'},
            course_id='some/other/course')
        self.assertEqual(404, response.status_code)

    def test_get_license_from_non_ajax(self):
        """Should respond 404 when the request is not marked as AJAX"""
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'false'}, ajax=False)
        self.assertEqual(404, response.status_code)

    def test_get_license_without_software(self):
        """Should respond 404 when no software name is posted"""
        response = self._post_license_request({'generate': 'false'})
        self.assertEqual(404, response.status_code)

    def test_get_license_without_login(self):
        """Should redirect when the client is not logged in"""
        self.client.logout()
        response = self._post_license_request(
            {'software': SOFTWARE_1, 'generate': 'false'})
        # if we're not logged in, we should be referred to the login page
        self.assertEqual(302, response.status_code)


146
@override_settings(MODULESTORE=TEST_DATA_MIXED_MODULESTORE)
class CommandTest(ModuleStoreTestCase):
    """Test management command for importing serial numbers"""

    def setUp(self):
        """Create a course whose id is handed to the import command."""
        # Chain to the base class so any setup it performs is not skipped.
        super(CommandTest, self).setUp()
        course = CourseFactory.create()
        self.course_id = course.id

    def _import_serials(self, software_name, serials_path):
        """Run ``import_serial_numbers`` for this course with the given file."""
        call_command('import_serial_numbers',
                     self.course_id, software_name, serials_path)

    def _import_generated_serials(self, software_name, size):
        """Import ``size`` freshly generated serials for ``software_name``."""
        with generate_serials_file(size) as temp_file:
            self._import_serials(software_name, temp_file.name)

    def test_import_serial_numbers(self):
        """Imported serials are stored per software; known serials are not re-added."""
        size = 20
        new_count = 10

        log.debug('Adding one set of serials for {0}'.format(SOFTWARE_1))
        self._import_generated_serials(SOFTWARE_1, size)

        log.debug('Adding one set of serials for {0}'.format(SOFTWARE_2))
        self._import_generated_serials(SOFTWARE_2, size)

        log.debug('There should be only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.all().count())

        log.debug('We added two sets of {0} serials'.format(size))
        self.assertEqual(2 * size, UserLicense.objects.all().count())

        log.debug('Adding more serial numbers to {0}'.format(SOFTWARE_1))
        self._import_generated_serials(SOFTWARE_1, size)

        log.debug('There should be still only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.all().count())

        log.debug('Now we should have 3 sets of {0} serials'.format(size))
        self.assertEqual(3 * size, UserLicense.objects.all().count())

        # Look the software up by name instead of a hard-coded primary key,
        # which is only 1 when the database sequence starts fresh.
        # NOTE(review): assumes the import command stores the software name in
        # CourseSoftware.name -- confirm against the command.
        software = CourseSoftware.objects.get(name=SOFTWARE_1)

        # Mix serials that are already stored with brand new ones.
        existing = UserLicense.objects.filter(software=software)[:size]
        known_serials = [lic.serial for lic in existing]
        known_serials.extend(generate_serials(new_count))

        shuffle(known_serials)

        log.debug('Adding some new and old serials to {0}'.format(SOFTWARE_1))
        with NamedTemporaryFile() as tmpfile:
            # NamedTemporaryFile opens in binary mode by default, so write bytes.
            tmpfile.write('\n'.join(known_serials).encode('utf-8'))
            tmpfile.flush()
            self._import_serials(SOFTWARE_1, tmpfile.name)

        log.debug('Check if we added only the new ones')
        licenses_count = UserLicense.objects.filter(software=software).count()
        self.assertEqual((2 * size) + new_count, licenses_count)


def generate_serials(size=20):
    """
    Generate a list of serial numbers.

    Returns a list of ``size`` strings, each the text form of a freshly
    generated random UUID (``uuid4``).
    """
    return [str(uuid4()) for _ in range(size)]


def generate_serials_file(size=20):
    """
    Output a list of generated serial numbers to a temp file.

    Writes ``size`` serials, one per line, to a new ``NamedTemporaryFile``,
    flushes it so the content can be read through ``temp_file.name``, and
    returns the still-open file object. The file is removed when the caller
    closes it (for example by using it as a context manager).
    """
    serials = generate_serials(size)

    temp_file = NamedTemporaryFile()
    # NamedTemporaryFile opens in binary mode by default, so the text must be
    # encoded before writing.
    temp_file.write('\n'.join(serials).encode('utf-8'))
    temp_file.flush()

    return temp_file