import logging
from uuid import uuid4
from random import shuffle
from tempfile import NamedTemporaryFile

from django.test import TestCase
from django.core.management import call_command

from .models import CourseSoftware, UserLicense

COURSE_1 = 'edX/toy/2012_Fall'

SOFTWARE_1 = 'matlab'
SOFTWARE_2 = 'stata'

log = logging.getLogger(__name__)


class CommandTest(TestCase):
    """Tests for the ``import_serial_numbers`` management command."""

    def _import_serials(self, software, filename):
        """Run ``import_serial_numbers`` for ``COURSE_1`` / *software* on *filename*."""
        call_command('import_serial_numbers', COURSE_1, software, filename)

    def test_import_serial_numbers(self):
        """Importing serials creates one CourseSoftware entry per software
        name and skips serials already imported for that software."""
        size = 20
        new_serials = 10

        log.debug('Adding one set of serials for %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file.name)

        log.debug('Adding one set of serials for %s', SOFTWARE_2)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_2, temp_file.name)

        log.debug('There should be only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.count())

        log.debug('We added two sets of %s serials', size)
        self.assertEqual(2 * size, UserLicense.objects.count())

        log.debug('Adding more serial numbers to %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file.name)

        log.debug('There should be still only 2 course-software entries')
        self.assertEqual(2, CourseSoftware.objects.count())

        log.debug('Now we should have 3 sets of %s serials', size)
        self.assertEqual(3 * size, UserLicense.objects.count())

        # SOFTWARE_1 was imported first, so its CourseSoftware row has the
        # lowest primary key. Don't hard-code pk=1: auto-increment sequences
        # are not reset between tests on every database backend.
        cs = CourseSoftware.objects.order_by('pk')[0]

        # Mix `size` serials already imported for this software with some
        # brand-new ones; only the new ones should create licenses.
        existing_licenses = UserLicense.objects.filter(software=cs)[:size]
        known_serials = [lic.serial for lic in existing_licenses]
        known_serials.extend(generate_serials(new_serials))
        shuffle(known_serials)

        log.debug('Adding some new and old serials to %s', SOFTWARE_1)
        with NamedTemporaryFile() as f:
            # NamedTemporaryFile opens in binary mode by default, so write
            # bytes (works on both Python 2 and Python 3).
            f.write('\n'.join(known_serials).encode('utf-8'))
            f.flush()  # the command reopens the file by name
            self._import_serials(SOFTWARE_1, f.name)

        log.debug('Check if we added only the new ones')
        licenses_count = UserLicense.objects.filter(software=cs).count()
        self.assertEqual((2 * size) + new_serials, licenses_count)


def generate_serials(size=20):
    """Return a list of *size* random serial numbers.

    Each serial is the string form of a freshly generated ``uuid4()``.
    """
    serial_numbers = [str(uuid4()) for _index in range(size)]
    return serial_numbers


def generate_serials_file(size=20):
    """Write *size* fresh serials to a new temporary file and return it.

    The serials are separated by newlines. The returned ``NamedTemporaryFile``
    is still open and is deleted when closed. Use it as a context manager and
    pass its ``.name`` to code that reopens the file by path.
    """
    serials = generate_serials(size)

    temp_file = NamedTemporaryFile()
    try:
        # NamedTemporaryFile opens in binary mode ('w+b'); write bytes so
        # this works on Python 3 as well as Python 2.
        temp_file.write('\n'.join(serials).encode('utf-8'))
        temp_file.flush()  # make the content visible to readers opening by name
    except Exception:
        # Don't leak the open (and otherwise undeleted) file on failure.
        temp_file.close()
        raise

    return temp_file