"""Tests for License package"""
import logging
import json

from uuid import uuid4
from random import shuffle
from tempfile import NamedTemporaryFile
import factory
from factory.django import DjangoModelFactory

from django.test import TestCase
from django.test.client import Client
from django.test.utils import override_settings
from django.core.management import call_command
from django.core.urlresolvers import reverse
from nose.tools import assert_true  # pylint: disable=no-name-in-module

from xmodule.modulestore.tests.django_utils import TEST_DATA_MOCK_MODULESTORE
from licenses.models import CourseSoftware, UserLicense

from student.tests.factories import UserFactory
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase

# Course id used as the CourseSoftwareFactory default and embedded in the
# HTTP_REFERER of the view test requests.
COURSE_1 = 'edX/toy/2012_Fall'

# Software names: SOFTWARE_1 is the CourseSoftwareFactory default and the name
# posted by the view tests; SOFTWARE_2 is a second name used by the
# serial-number import test.
SOFTWARE_1 = 'matlab'
SOFTWARE_2 = 'stata'

# Default serial number given to licenses built by UserLicenseFactory.
SERIAL_1 = '123456abcde'

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class CourseSoftwareFactory(DjangoModelFactory):
    '''
    Factory that builds CourseSoftware rows for the tests.

    Unless overridden, the software belongs to COURSE_1 and its name,
    full name and url are all set to SOFTWARE_1.
    '''
    FACTORY_FOR = CourseSoftware

    course_id = COURSE_1
    # The three descriptive fields share the same default value.
    name = full_name = url = SOFTWARE_1


class UserLicenseFactory(DjangoModelFactory):
    '''
    Factory that builds UserLicense rows for the tests.

    The default license carries the serial SERIAL_1, belongs to a
    CourseSoftware produced by CourseSoftwareFactory, and has no user,
    which marks the serial number as not yet assigned.
    '''
    FACTORY_FOR = UserLicense

    serial = SERIAL_1
    software = factory.SubFactory(CourseSoftwareFactory)
    # A null user means the serial is still unassigned.
    user = None


class LicenseTestCase(TestCase):
    '''
    Tests for licenses.views

    Every test posts to the 'user_software_license' URL through the test
    client and checks the response status code and, for successful
    requests, the returned JSON payload.
    '''

    def setUp(self):
        '''Create a user, log the test client in and create a CourseSoftware entry'''
        self.user = UserFactory(username='test',
                                email='test@edx.org', password='test_password')
        self.client = Client()
        assert_true(self.client.login(username='test', password='test_password'))
        self.software = CourseSoftwareFactory()

    def _post_license_request(self, data, course_id=COURSE_1, ajax=True):
        '''
        Post `data` to the user_software_license view and return the response.

        `course_id` is embedded in the HTTP_REFERER header. When `ajax` is
        true the request also carries the X-Requested-With header that marks
        it as an XMLHttpRequest.
        '''
        headers = {'HTTP_REFERER': '/courses/{0}/some_page'.format(course_id)}
        if ajax:
            headers['HTTP_X_REQUESTED_WITH'] = 'XMLHttpRequest'
        return self.client.post(reverse('user_software_license'), data, **headers)

    def _assert_serial_returned(self, response):
        '''Check that `response` is a 200 whose JSON holds SERIAL_1 and no error'''
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('error', json_returned)
        self.assertIn('serial', json_returned)
        self.assertEqual(json_returned['serial'], SERIAL_1)

    def _assert_error_returned(self, response):
        '''Check that `response` is a 200 whose JSON holds an error and no serial'''
        self.assertEqual(200, response.status_code)
        json_returned = json.loads(response.content)
        self.assertNotIn('serial', json_returned)
        self.assertIn('error', json_returned)

    def test_get_license(self):
        '''Should return the serial of a license already assigned to the user'''
        UserLicenseFactory(user=self.user, software=self.software)
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_serial_returned(response)

    def test_get_nonexistent_license(self):
        '''Should return an error when the user has no license and none is requested'''
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'false'})
        self._assert_error_returned(response)

    def test_create_nonexistent_license(self):
        '''Should not assign a license to an unlicensed user when none are available'''
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_error_returned(response)

    def test_create_license(self):
        '''Should assign a license to an unlicensed user if one is unassigned'''
        # create an unassigned license
        UserLicenseFactory(software=self.software)
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'true'})
        self._assert_serial_returned(response)

    def test_get_license_from_wrong_course(self):
        '''Should 404 when the referer points at a different course'''
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'false'},
                                              course_id='some/other/course')
        self.assertEqual(404, response.status_code)

    def test_get_license_from_non_ajax(self):
        '''Should 404 when the request is not marked as an XMLHttpRequest'''
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'false'},
                                              ajax=False)
        self.assertEqual(404, response.status_code)

    def test_get_license_without_software(self):
        '''Should 404 when no software name is posted'''
        response = self._post_license_request({'generate': 'false'})
        self.assertEqual(404, response.status_code)

    def test_get_license_without_login(self):
        '''Should redirect when the client is not logged in'''
        self.client.logout()
        response = self._post_license_request({'software': SOFTWARE_1, 'generate': 'false'})
        # if we're not logged in, we should be referred to the login page
        self.assertEqual(302, response.status_code)


@override_settings(MODULESTORE=TEST_DATA_MOCK_MODULESTORE)
class CommandTest(ModuleStoreTestCase):
    '''Test management command for importing serial numbers'''

    def setUp(self):
        '''Create a course whose id is handed to the import command'''
        # Let the base test case run whatever set-up it defines first.
        super(CommandTest, self).setUp()
        course = CourseFactory.create()
        self.course_id = course.id

    def _import_serials(self, software_name, serials_file):
        '''Run import_serial_numbers for this course, `software_name` and the named file'''
        call_command(
            'import_serial_numbers',
            self.course_id.to_deprecated_string(),
            software_name,
            serials_file.name
        )

    def _assert_counts(self, software_count, licenses_count):
        '''Check the total number of CourseSoftware and UserLicense rows'''
        self.assertEqual(software_count, CourseSoftware.objects.all().count())
        self.assertEqual(licenses_count, UserLicense.objects.all().count())

    def test_import_serial_numbers(self):
        '''
        Import serials several times and check the resulting row counts.

        Covers a first import per software name, a second import for an
        existing software name, and an import mixing already-known serials
        with new ones.
        '''
        size = 20

        log.debug('Adding one set of serials for %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file)

        log.debug('Adding one set of serials for %s', SOFTWARE_2)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_2, temp_file)

        log.debug('There should be only 2 course-software entries')
        log.debug('We added two sets of %s serials', size)
        self._assert_counts(2, 2 * size)

        log.debug('Adding more serial numbers to %s', SOFTWARE_1)
        with generate_serials_file(size) as temp_file:
            self._import_serials(SOFTWARE_1, temp_file)

        log.debug('There should be still only 2 course-software entries')
        log.debug('Now we should have 3 sets of %s serials', size)
        self._assert_counts(2, 3 * size)

        # Look the entry up by name instead of assuming it received pk=1;
        # the count checks above guarantee one entry per software name.
        software = CourseSoftware.objects.get(name=SOFTWARE_1)

        # Mix `size` serials that are already stored with 10 fresh ones.
        existing_licenses = UserLicense.objects.filter(software=software)[:size]
        known_serials = [user_license.serial for user_license in existing_licenses]
        known_serials.extend(generate_serials(10))

        shuffle(known_serials)

        log.debug('Adding some new and old serials to %s', SOFTWARE_1)
        with NamedTemporaryFile() as tmpfile:
            tmpfile.write('\n'.join(known_serials))
            # flush so the command sees the content when it opens the file by name
            tmpfile.flush()
            self._import_serials(SOFTWARE_1, tmpfile)

        log.debug('Check if we added only the new ones')
        licenses_count = UserLicense.objects.filter(software=software).count()
        self.assertEqual((2 * size) + 10, licenses_count)


def generate_serials(size=20):
    '''Return a list of `size` serial numbers, each the string form of a fresh uuid4'''
    serials = []
    for _ in range(size):
        serials.append(str(uuid4()))
    return serials


def generate_serials_file(size=20):
    '''
    Write `size` generated serial numbers, one per line, to a temp file.

    Returns the still-open NamedTemporaryFile after flushing it; the caller
    is responsible for closing it (for example by using it in a `with`).
    '''
    handle = NamedTemporaryFile()
    contents = '\n'.join(generate_serials(size))
    handle.write(contents)
    # flush so the data is visible to anything opening the file by name
    handle.flush()
    return handle