"""
Tests for the embargo API.
"""

from contextlib import contextmanager
import mock
import unittest
import pygeoip
import ddt

from django.conf import settings
from django.test.utils import override_settings
from django.core.cache import cache
from django.db import connection, transaction

from student.tests.factories import UserFactory
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import (
    ModuleStoreTestCase, mixed_store_config
)

from embargo.models import (
    RestrictedCourse, Country, CountryAccessRule,
)

from util.testing import UrlResetMixin
from embargo import api as embargo_api
from embargo.exceptions import InvalidAccessPoint
from mock import patch


# None of these tests rely on XML course fixtures, so build a
# modulestore configuration with the XML modulestore turned off.
MODULESTORE_CONFIG = mixed_store_config(
    settings.COMMON_TEST_DATA_ROOT,
    {},
    include_xml=False
)


@ddt.ddt
@override_settings(MODULESTORE=MODULESTORE_CONFIG)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@mock.patch.dict(settings.FEATURES, {'EMBARGO': True})
class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
    """Exercise `embargo_api.check_course_access` against country access rules. """

    def setUp(self):
        super(EmbargoCheckAccessApiTests, self).setUp()
        self.course = CourseFactory.create()
        self.user = UserFactory.create()
        self.restricted_course = RestrictedCourse.objects.create(course_key=self.course.id)

        # Countries that the individual tests refer to when building rules.
        for country_code in ('US', 'IR', 'CU'):
            Country.objects.create(country=country_code)

        # Start every test from an empty cache so that values cached
        # by one test cannot influence another.
        cache.clear()

    @ddt.data(
        # IP country, profile_country, blacklist, whitelist, allow_access
        ('US', None, [], [], True),
        ('IR', None, ['IR', 'CU'], [], False),
        ('US', 'IR', ['IR', 'CU'], [], False),
        ('IR', 'IR', ['IR', 'CU'], [], False),
        ('US', None, [], ['US'], True),
        ('IR', None, [], ['US'], False),
        ('US', 'IR', [], ['US'], False),
    )
    @ddt.unpack
    def test_country_access_rules(self, ip_country, profile_country, blacklist, whitelist, allow_access):
        # Set up the access rules: whitelist entries first, then blacklist entries.
        rules_to_create = (
            (CountryAccessRule.WHITELIST_RULE, whitelist),
            (CountryAccessRule.BLACKLIST_RULE, blacklist),
        )
        for rule_type, country_codes in rules_to_create:
            for country_code in country_codes:
                self._create_access_rule(rule_type, country_code)

        # Only touch the profile when the scenario specifies a profile country.
        if profile_country is not None:
            self.user.profile.country = profile_country
            self.user.profile.save()

        # Geo-location is mocked to report `ip_country`, so the literal
        # IP address handed to the API is irrelevant.
        with self._mock_geoip(ip_country):
            has_access = embargo_api.check_course_access(
                self.course.id,
                user=self.user,
                ip_address='0.0.0.0'
            )

        # The outcome must match what the scenario expects.
        self.assertEqual(has_access, allow_access)

    def test_no_user_has_access(self):
        self._create_access_rule(CountryAccessRule.BLACKLIST_RULE, 'US')

        # No user is passed in, because the user has not been authenticated.
        has_access = embargo_api.check_course_access(self.course.id, ip_address='0.0.0.0')
        self.assertTrue(has_access)

    def test_no_user_blocked(self):
        self._create_access_rule(CountryAccessRule.BLACKLIST_RULE, 'US')

        with self._mock_geoip('US'):
            # No user is passed in, because the user has not been authenticated.
            has_access = embargo_api.check_course_access(self.course.id, ip_address='0.0.0.0')

        self.assertFalse(has_access)

    def test_course_not_restricted(self):
        # This course key has no restricted course model,
        # so all access checks should be skipped.
        unrestricted_course = CourseFactory.create()

        # The first check costs a single query; the second check
        # should require no database queries at all.
        for expected_queries in (1, 0):
            with self.assertNumQueries(expected_queries):
                embargo_api.check_course_access(
                    unrestricted_course.id,
                    user=self.user,
                    ip_address='0.0.0.0'
                )

    def test_ip_v6(self):
        # Scenario that goes through every check
        # (the course is restricted, but all the checks pass).
        self.assertTrue(
            embargo_api.check_course_access(
                self.course.id,
                user=self.user,
                ip_address='FE80::0202:B3FF:FE1E:8329'
            )
        )

    def test_country_access_fallback_to_continent_code(self):
        # Simulate PyGeoIP handing back a continent code rather
        # than a country code.  The user should be allowed access
        # in that case.
        with self._mock_geoip('EU'):
            has_access = embargo_api.check_course_access(
                self.course.id,
                user=self.user,
                ip_address='0.0.0.0'
            )

        self.assertTrue(has_access)

    @mock.patch.dict(settings.FEATURES, {'EMBARGO': True})
    def test_profile_country_db_null(self):
        # Django country fields are inconsistent about NULL values.
        # Saving a profile whose country is None writes an empty string to
        # the database, yet loading a NULL value from the database sets
        # `country.code` to `None`.  Country values created by the original
        # South schema migration defaulted to NULL, and that caused a runtime
        # exception when the embargo middleware treated the value as a string.
        # Assigning `profile.country = None` cannot reproduce this (saving
        # would store an empty string instead of NULL), so the NULL is
        # written with raw SQL.
        cursor = connection.cursor()
        profile_id = str(self.user.profile.id)
        cursor.execute(
            "UPDATE auth_userprofile SET country = NULL WHERE id = %s",
            [profile_id]
        )
        transaction.commit_unless_managed()

        # Checking the user's access must not raise an error
        has_access = embargo_api.check_course_access(
            self.course.id,
            user=self.user,
            ip_address='0.0.0.0'
        )
        self.assertTrue(has_access)

    def test_caching(self):
        with self._mock_geoip('US'):
            # Scenario that goes through every check (the course is
            # restricted, but all the checks pass).  This is the worst
            # case, so it hits all of the caching code: three queries
            # on the first call, none on the repeat call.
            for expected_queries in (3, 0):
                with self.assertNumQueries(expected_queries):
                    embargo_api.check_course_access(
                        self.course.id,
                        user=self.user,
                        ip_address='0.0.0.0'
                    )

    def test_caching_no_restricted_courses(self):
        RestrictedCourse.objects.all().delete()
        cache.clear()

        # One query on the first call, none once the result is cached.
        for expected_queries in (1, 0):
            with self.assertNumQueries(expected_queries):
                embargo_api.check_course_access(
                    self.course.id,
                    user=self.user,
                    ip_address='0.0.0.0'
                )

    def _create_access_rule(self, rule_type, country_code):
        """Attach a rule of `rule_type` for `country_code` to the restricted course. """
        CountryAccessRule.objects.create(
            rule_type=rule_type,
            restricted_course=self.restricted_course,
            country=Country.objects.get(country=country_code)
        )

    @contextmanager
    def _mock_geoip(self, country_code):
        """Patch `pygeoip.GeoIP.country_code_by_addr` to return `country_code`. """
        patched_lookup = mock.patch.object(
            pygeoip.GeoIP,
            'country_code_by_addr',
            return_value=country_code
        )
        with patched_lookup:
            yield


@ddt.ddt
@override_settings(MODULESTORE=MODULESTORE_CONFIG)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EmbargoMessageUrlApiTests(UrlResetMixin, ModuleStoreTestCase):
    """Exercise `embargo_api.message_url_path`, which returns blocking message URLs. """

    @patch.dict(settings.FEATURES, {'EMBARGO': True})
    def setUp(self):
        # NOTE(review): 'embargo' is forwarded to the parent setUp; presumably it
        # names the URL configuration that UrlResetMixin reloads -- confirm there.
        super(EmbargoMessageUrlApiTests, self).setUp('embargo')
        self.course = CourseFactory.create()

    def tearDown(self):
        super(EmbargoMessageUrlApiTests, self).tearDown()
        # Drop anything these tests cached.
        cache.clear()

    @ddt.data(
        ('enrollment', '/embargo/blocked-message/enrollment/embargo/'),
        ('courseware', '/embargo/blocked-message/courseware/embargo/')
    )
    @ddt.unpack
    def test_message_url_path(self, access_point, expected_url_path):
        self._restrict_course(self.course.id)

        # The blocked-message URL should match the access point
        self.assertEqual(
            embargo_api.message_url_path(self.course.id, access_point),
            expected_url_path
        )

    def test_message_url_path_caching(self):
        self._restrict_course(self.course.id)

        # Retrieving the message for the first time needs two database
        # queries; the repeat lookup should be served from cached values.
        for expected_queries in (2, 0):
            with self.assertNumQueries(expected_queries):
                embargo_api.message_url_path(self.course.id, "enrollment")

    @ddt.data('enrollment', 'courseware')
    def test_message_url_path_no_restrictions_for_course(self, access_point):
        # The course has no restrictions, so a default path is used
        self.assertEqual(
            embargo_api.message_url_path(self.course.id, access_point),
            '/embargo/blocked-message/courseware/default/'
        )

    def test_invalid_access_point(self):
        with self.assertRaises(InvalidAccessPoint):
            embargo_api.message_url_path(self.course.id, "invalid")

    def test_message_url_stale_cache(self):
        # Fetch the URL once so that the cache gets populated
        # with the list of restricted courses.
        self._restrict_course(self.course.id)
        embargo_api.message_url_path(self.course.id, 'courseware')

        # Remove the restricted course entry
        RestrictedCourse.objects.get(course_key=self.course.id).delete()

        # Evict only the cached message URL for this course
        cache.delete(
            'embargo.message_url_path.courseware.{course_key}'.format(
                course_key=self.course.id
            )
        )

        # Ask again.  The cached results are now stale, but
        # a valid URL should still come back.
        self.assertEqual(
            embargo_api.message_url_path(self.course.id, 'courseware'),
            '/embargo/blocked-message/courseware/default/'
        )

    def _restrict_course(self, course_key):
        """Mark `course_key` as restricted and give it a blacklist rule for one country. """
        blocked_country = Country.objects.create(country='us')
        course_restriction = RestrictedCourse.objects.create(
            course_key=course_key,
            enroll_msg_key='embargo',
            access_msg_key='embargo'
        )
        CountryAccessRule.objects.create(
            restricted_course=course_restriction,
            rule_type=CountryAccessRule.BLACKLIST_RULE,
            country=blocked_country
        )