Commit 320eb0c6 by ayub-khan

edx_clear_expired_tokens management commands removes

all expired tokens added a exlude_application_ids argument which
enable us to not remove expired tokens for given applications.
LEARNER-717
parent b964b528
...@@ -31,6 +31,12 @@ class Command(BaseCommand): ...@@ -31,6 +31,12 @@ class Command(BaseCommand):
type=int, type=int,
default=10, default=10,
help='Sleep time between deletion of batches') help='Sleep time between deletion of batches')
parser.add_argument('--excluded-application-ids',
action='store',
dest='excluded-application-ids',
type=str,
default='',
help='Comma-separated list of application IDs for which tokens will NOT be removed')
def clear_table_data(self, query_set, batch_size, model, sleep_time): def clear_table_data(self, query_set, batch_size, model, sleep_time):
message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__) message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__)
...@@ -57,11 +63,16 @@ class Command(BaseCommand): ...@@ -57,11 +63,16 @@ class Command(BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
batch_size = options['batch_size'] batch_size = options['batch_size']
sleep_time = options['sleep_time'] sleep_time = options['sleep_time']
if options['excluded-application-ids']:
excluded_application_ids = [int(x) for x in options['excluded-application-ids'].split(',')]
else:
excluded_application_ids = []
now = timezone.now() now = timezone.now()
refresh_expire_at = self.get_expiration_time(now) refresh_expire_at = self.get_expiration_time(now)
query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at) query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at).exclude(
application_id__in=excluded_application_ids)
self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time) self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time)
query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now) query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now)
......
from datetime import timedelta
import unittest import unittest
from datetime import timedelta
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.management import call_command from django.core.management import call_command
from django.db.models import QuerySet from django.db.models import QuerySet
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone from django.utils import timezone
from mock import patch from mock import patch
from oauth2_provider.models import AccessToken from oauth2_provider.models import AccessToken, RefreshToken
from testfixtures import LogCapture from testfixtures import LogCapture
from openedx.core.djangoapps.oauth_dispatch.tests import factories from openedx.core.djangoapps.oauth_dispatch.tests import factories
...@@ -33,6 +34,7 @@ def counter(fn): ...@@ -33,6 +34,7 @@ def counter(fn):
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EdxClearExpiredTokensTests(TestCase): class EdxClearExpiredTokensTests(TestCase):
# patching REFRESH_TOKEN_EXPIRE_SECONDS because override_settings not working.
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz') @patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz')
def test_invalid_expiration_time(self): def test_invalid_expiration_time(self):
with LogCapture(LOGGER_NAME) as log: with LogCapture(LOGGER_NAME) as log:
...@@ -46,8 +48,37 @@ class EdxClearExpiredTokensTests(TestCase): ...@@ -46,8 +48,37 @@ class EdxClearExpiredTokensTests(TestCase):
) )
) )
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 3600) @override_settings()
def test_excluded_application_ids(self):
settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600
expires = timezone.now() - timedelta(days=1)
application = factories.ApplicationFactory()
access_token = factories.AccessTokenFactory(user=application.user, application=application, expires=expires)
factories.RefreshTokenFactory(user=application.user, application=application, access_token=access_token)
with LogCapture(LOGGER_NAME) as log:
call_command('edx_clear_expired_tokens', sleep_time=0, excluded_application_ids=str(application.id))
log.check(
(
LOGGER_NAME,
'INFO',
'Cleaning {} rows from {} table'.format(0, RefreshToken.__name__)
),
(
LOGGER_NAME,
'INFO',
'Cleaning {} rows from {} table'.format(0, AccessToken.__name__),
),
(
LOGGER_NAME,
'INFO',
'Cleaning 0 rows from Grant table',
)
)
self.assertTrue(RefreshToken.objects.filter(application=application).exists())
@override_settings()
def test_clear_expired_tokens(self): def test_clear_expired_tokens(self):
settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 3600
initial_count = 5 initial_count = 5
now = timezone.now() now = timezone.now()
expires = now - timedelta(days=1) expires = now - timedelta(days=1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment