import functools
import unittest
from datetime import timedelta

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.management import call_command
from django.db.models import QuerySet
from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone
from mock import patch
from oauth2_provider.models import AccessToken, RefreshToken
from testfixtures import LogCapture

from openedx.core.djangoapps.oauth_dispatch.tests import factories
from student.tests.factories import UserFactory

# Logger name whose output the tests in this module capture with LogCapture.
# It equals the dotted module path of the edx_clear_expired_tokens command
# (presumably that module logs via logging.getLogger(__name__) — confirm there).
LOGGER_NAME = 'openedx.core.djangoapps.oauth_dispatch.management.commands.edx_clear_expired_tokens'


def counter(fn):
    """
    Wrap ``fn`` so that each call increments the wrapper's ``invocations`` attribute.

    The wrapper forwards all positional and keyword arguments to ``fn`` and
    returns ``fn``'s result unchanged. It can therefore transparently stand in
    for the original callable, e.g. when patched over a model/queryset method
    whose return value callers use. ``functools.wraps`` copies the wrapped
    callable's name, docstring and attribute dict (such as Django's
    ``alters_data`` flag) onto the wrapper.

    Source: http://code.activestate.com/recipes/577534-counting-decorator/
    """
    @functools.wraps(fn)
    def _counted(*largs, **kargs):
        _counted.invocations += 1
        # Propagate the result; dropping it would make the wrapper return None.
        return fn(*largs, **kargs)

    # Set after wraps() so a copied ``invocations`` attribute (when wrapping an
    # already-counted function) cannot leak in as the starting value.
    _counted.invocations = 0
    return _counted


@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EdxClearExpiredTokensTests(TestCase):
    """
    Tests for the ``edx_clear_expired_tokens`` management command.
    """

    @staticmethod
    def _refresh_token_expire_seconds(seconds):
        """
        Return a context manager that sets
        ``settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS']`` to ``seconds``.

        The existing dict is patched in place rather than replaced, so any code
        holding a reference to that same dict object sees the new value.
        ``patch.dict`` restores the dict's original contents on exit, so the
        value does not leak into other tests. (An argument-less
        ``override_settings()`` would not undo an in-place mutation.)
        """
        return patch.dict(settings.OAUTH2_PROVIDER, {'REFRESH_TOKEN_EXPIRE_SECONDS': seconds})

    # patching REFRESH_TOKEN_EXPIRE_SECONDS because override_settings not working.
    @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz')
    def test_invalid_expiration_time(self):
        """
        A non-numeric REFRESH_TOKEN_EXPIRE_SECONDS makes the command raise
        ImproperlyConfigured.
        """
        with self.assertRaises(ImproperlyConfigured) as context:
            call_command('edx_clear_expired_tokens')
        # NOTE(review): this replaces a log.check() that sat after call_command
        # inside the assertRaises block and therefore never executed. The
        # message is now asserted on the raised exception; confirm it matches
        # the text used by the command.
        self.assertIn(
            'REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds',
            str(context.exception),
        )

    def test_excluded_application_ids(self):
        """
        Expired tokens that belong to an application passed in
        ``excluded_application_ids`` are kept, and the command logs that it
        cleans 0 rows from each table.
        """
        with self._refresh_token_expire_seconds(3600):
            expires = timezone.now() - timedelta(days=1)
            application = factories.ApplicationFactory()
            access_token = factories.AccessTokenFactory(
                user=application.user, application=application, expires=expires
            )
            factories.RefreshTokenFactory(user=application.user, application=application, access_token=access_token)
            with LogCapture(LOGGER_NAME) as log:
                call_command('edx_clear_expired_tokens', sleep_time=0, excluded_application_ids=str(application.id))
                log.check(
                    (LOGGER_NAME, 'INFO', 'Cleaning 0 rows from {} table'.format(RefreshToken.__name__)),
                    (LOGGER_NAME, 'INFO', 'Cleaning 0 rows from {} table'.format(AccessToken.__name__)),
                    (LOGGER_NAME, 'INFO', 'Cleaning 0 rows from Grant table'),
                )
            self.assertTrue(RefreshToken.objects.filter(application=application).exists())

    def test_clear_expired_tokens(self):
        """
        All expired access tokens without a refresh token are removed.

        With ``batch_size=1``, the test expects one ``QuerySet.delete`` call per token.
        """
        initial_count = 5
        with self._refresh_token_expire_seconds(3600):
            now = timezone.now()
            expires = now - timedelta(days=1)
            for user in UserFactory.create_batch(initial_count):
                application = factories.ApplicationFactory(user=user)
                factories.AccessTokenFactory(user=user, application=application, expires=expires)
            self.assertEqual(
                AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(),
                initial_count
            )

            # Count deletes only while the command runs. patch.object restores
            # the real QuerySet.delete afterwards instead of leaving a global
            # monkey-patch in place for the rest of the test run.
            counted_delete = counter(QuerySet.delete)
            with patch.object(QuerySet, 'delete', counted_delete):
                call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0)

            self.assertEqual(counted_delete.invocations, initial_count)
            self.assertEqual(AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), 0)
