"""
Tests for the licenses package: the ``user_software_license`` view and the
``import_serial_numbers`` management command.
"""

import json
import logging
from random import shuffle
from tempfile import NamedTemporaryFile
from uuid import uuid4

import factory
from factory.django import DjangoModelFactory

from django.core.management import call_command
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import Client
from nose.tools import assert_true  # pylint: disable=no-name-in-module

from licenses.models import CourseSoftware, UserLicense

from student.tests.factories import UserFactory

from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase

# Fixture identifiers shared by the factories and tests in this module.
COURSE_1 = 'edX/toy/2012_Fall'

# Two distinct software packages, so imports can be checked per package.
SOFTWARE_1, SOFTWARE_2 = 'matlab', 'stata'

# Serial number handed out by UserLicenseFactory by default.
SERIAL_1 = '123456abcde'

log = logging.getLogger(__name__)


class CourseSoftwareFactory(DjangoModelFactory):
    '''
    Factory for CourseSoftware rows in the test database.

    All three descriptive fields (name, full_name, url) default to
    SOFTWARE_1, and the software belongs to COURSE_1.
    '''
    FACTORY_FOR = CourseSoftware

    name = full_name = url = SOFTWARE_1
    course_id = COURSE_1


class UserLicenseFactory(DjangoModelFactory):
    '''
    Factory for UserLicense rows in the test database.

    Unless overridden, the row carries serial SERIAL_1 for a software entry
    built by CourseSoftwareFactory and has no user attached, i.e. it models
    a serial number that has not been handed out yet.
    '''
    FACTORY_FOR = UserLicense

    # no owner: the serial is still unassigned
    user = None
    # built via CourseSoftwareFactory unless the caller passes `software=...`
    software = factory.SubFactory(CourseSoftwareFactory)
    serial = SERIAL_1


class LicenseTestCase(TestCase):
    '''Tests for licenses.views'''

    def setUp(self):
        '''Create a user, log in as that user, and create a CourseSoftware entry'''
        super(LicenseTestCase, self).setUp()
        self.user = UserFactory(username='test',
                                email='test@edx.org', password='test_password')
        self.client = Client()
        assert_true(self.client.login(username='test', password='test_password'))
        self.software = CourseSoftwareFactory()

    def _request_license(self, data, ajax=True, course_id=COURSE_1):
        '''
        POST ``data`` to the ``user_software_license`` view and return the response.

        The request's referer is a page inside ``course_id``. When ``ajax`` is
        true, the ``X-Requested-With: XMLHttpRequest`` header is sent as well.
        '''
        headers = {'HTTP_REFERER': '/courses/{0}/some_page'.format(course_id)}
        if ajax:
            headers['HTTP_X_REQUESTED_WITH'] = 'XMLHttpRequest'
        return self.client.post(reverse('user_software_license'), data, **headers)

    def _assert_serial_returned(self, response, serial=SERIAL_1):
        '''Assert ``response`` is a 200 JSON body containing ``serial`` and no error'''
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('error', json_returned)
        self.assertIn('serial', json_returned)
        self.assertEqual(json_returned['serial'], serial)

    def _assert_error_returned(self, response):
        '''Assert ``response`` is a 200 JSON body containing an error and no serial'''
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('serial', json_returned)
        self.assertIn('error', json_returned)

    def test_get_license(self):
        '''A user who already holds a license gets its serial back'''
        UserLicenseFactory(user=self.user, software=self.software)
        response = self._request_license({'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_serial_returned(response)

    def test_get_nonexistent_license(self):
        '''Without generate, a user with no license gets an error instead of a serial'''
        response = self._request_license({'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_error_returned(response)

    def test_create_nonexistent_license(self):
        '''Should not assign a license to an unlicensed user when none are available'''
        response = self._request_license({'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_error_returned(response)

    def test_create_license(self):
        '''Should assign a license to an unlicensed user if one is unassigned'''
        # create an unassigned license
        UserLicenseFactory(software=self.software)
        response = self._request_license({'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_serial_returned(response)

    def test_get_license_from_wrong_course(self):
        '''A referer from a course the software is not attached to yields 404'''
        response = self._request_license(
            {'software': SOFTWARE_1, 'generate': 'false'},
            course_id='some/other/course',
        )
        self.assertEqual(404, response.status_code)

    def test_get_license_from_non_ajax(self):
        '''A request without the XMLHttpRequest header yields 404'''
        response = self._request_license(
            {'software': SOFTWARE_1, 'generate': 'false'},
            ajax=False,
        )
        self.assertEqual(404, response.status_code)

    def test_get_license_without_software(self):
        '''A request that does not name the software yields 404'''
        response = self._request_license({'generate': 'false'})
        self.assertEqual(404, response.status_code)

    def test_get_license_without_login(self):
        '''Anonymous requests are redirected'''
        self.client.logout()
        response = self._request_license({'software': SOFTWARE_1, 'generate': 'false'})
        # if we're not logged in, we should be referred to the login page
        self.assertEqual(302, response.status_code)


class CommandTest(ModuleStoreTestCase):
    '''Test management command for importing serial numbers'''

    def setUp(self):
        '''Create a course for the imported serial numbers to belong to'''
        super(CommandTest, self).setUp()
        course = CourseFactory.create()
        self.course_id = course.id

    def _import_serials(self, software_name, serials_path):
        '''Run ``import_serial_numbers`` for ``software_name`` on this test's course'''
        call_command(
            'import_serial_numbers',
            self.course_id.to_deprecated_string(),
            software_name,
            serials_path,
        )

    def test_import_serial_numbers(self):
        '''Importing creates one CourseSoftware per name and skips serials already imported'''
        size = 20
        new_count = 10

        log.debug('Adding one set of serials for %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file.name)

        log.debug('Adding one set of serials for %s', SOFTWARE_2)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_2, temp_file.name)

        log.debug('There should be only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.count())

        log.debug('We added two sets of %s serials', size)
        self.assertEqual(2 * size, UserLicense.objects.count())

        log.debug('Adding more serial numbers to %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file.name)

        log.debug('There should be still only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.count())

        log.debug('Now we should have 3 sets of %s serials', size)
        self.assertEqual(3 * size, UserLicense.objects.count())

        # Look the row up by name rather than assuming pk=1: auto-increment
        # counters are not reset on rollback on every database backend, and
        # other tests also create CourseSoftware rows.
        software = CourseSoftware.objects.get(name=SOFTWARE_1)

        # Mix `size` serials that were already imported with `new_count` fresh ones.
        existing = UserLicense.objects.filter(software=software)[:size]
        known_serials = [lic.serial for lic in existing]
        known_serials.extend(generate_serials(new_count))
        shuffle(known_serials)

        log.debug('Adding some new and old serials to %s', SOFTWARE_1)
        # text mode because str is written (the default 'w+b' rejects str on Python 3)
        with NamedTemporaryFile(mode='w') as tmpfile:
            tmpfile.write('\n'.join(known_serials))
            tmpfile.flush()
            self._import_serials(SOFTWARE_1, tmpfile.name)

        log.debug('Check if we added only the new ones')
        licenses_count = UserLicense.objects.filter(software=software).count()
        self.assertEqual((2 * size) + new_count, licenses_count)


def generate_serials(size=20):
    '''Return a list of ``size`` serial numbers, each the string form of a fresh uuid4'''
    fresh_ids = (uuid4() for _ in range(size))
    return list(map(str, fresh_ids))


def generate_serials_file(size=20):
    '''
    Write ``size`` freshly generated serial numbers to a named temporary file.

    The serials are newline-separated and flushed, so other code can read the
    file through its ``.name``. The file is returned still open. The caller
    must close it (NamedTemporaryFile deletes it on close by default),
    typically by using it as a context manager.
    '''
    serials = generate_serials(size)

    # Text mode: the serials are str. The default 'w+b' mode accepts str only
    # on Python 2 and raises TypeError on Python 3.
    temp_file = NamedTemporaryFile(mode='w')
    try:
        temp_file.write('\n'.join(serials))
        temp_file.flush()
    except BaseException:
        # don't leak the open temp file if writing fails
        temp_file.close()
        raise

    return temp_file