Commit 5aa450ec by Ayub khan Committed by GitHub

Merge pull request #16279 from edx/LEARNER-717-2

Management Command to clear DOT expired data
parents 00d8b2a4 790150a8
from __future__ import unicode_literals
import logging
from datetime import timedelta
from time import sleep
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import BaseCommand
from django.db import transaction
from django.utils import timezone
from oauth2_provider.models import AccessToken, Grant, RefreshToken
from oauth2_provider.settings import oauth2_settings
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = "Clear expired access tokens and refresh tokens for Django OAuth Toolkit"
def add_arguments(self, parser):
parser.add_argument('--batch_size',
action='store',
dest='batch_size',
type=int,
default=1000,
help='Maximum number of database rows to delete per query. '
'This helps avoid locking the database when deleting large amounts of data.')
parser.add_argument('--sleep_time',
action='store',
dest='sleep_time',
type=int,
default=10,
help='Sleep time between deletion of batches')
def clear_table_data(self, query_set, batch_size, model, sleep_time):
message = 'Cleaning {} rows from {} table'.format(query_set.count(), model.__name__)
logger.info(message)
while query_set.exists():
qs = query_set[:batch_size]
batch_ids = qs.values_list('id', flat=True)
with transaction.atomic():
model.objects.filter(pk__in=list(batch_ids)).delete()
if query_set.exists():
sleep(sleep_time)
def get_expiration_time(self, now):
refresh_token_expire_seconds = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS
if not isinstance(refresh_token_expire_seconds, timedelta):
try:
refresh_token_expire_seconds = timedelta(seconds=refresh_token_expire_seconds)
except TypeError:
e = "REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds"
raise ImproperlyConfigured(e)
return now - refresh_token_expire_seconds
def handle(self, *args, **options):
batch_size = options['batch_size']
sleep_time = options['sleep_time']
now = timezone.now()
refresh_expire_at = self.get_expiration_time(now)
query_set = RefreshToken.objects.filter(access_token__expires__lt=refresh_expire_at)
self.clear_table_data(query_set, batch_size, RefreshToken, sleep_time)
query_set = AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now)
self.clear_table_data(query_set, batch_size, AccessToken, sleep_time)
query_set = Grant.objects.filter(expires__lt=now)
self.clear_table_data(query_set, batch_size, Grant, sleep_time)
from datetime import timedelta
import unittest
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.management import call_command
from django.db.models import QuerySet
from django.test import TestCase
from django.utils import timezone
from mock import patch
from oauth2_provider.models import AccessToken
from testfixtures import LogCapture
from openedx.core.djangoapps.oauth_dispatch.tests import factories
from student.tests.factories import UserFactory
LOGGER_NAME = 'openedx.core.djangoapps.oauth_dispatch.management.commands.edx_clear_expired_tokens'
def counter(fn):
"""
Adds a call counter to the given function.
Source: http://code.activestate.com/recipes/577534-counting-decorator/
"""
def _counted(*largs, **kargs):
_counted.invocations += 1
fn(*largs, **kargs)
_counted.invocations = 0
return _counted
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class EdxClearExpiredTokensTests(TestCase):
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 'xyz')
def test_invalid_expiration_time(self):
with LogCapture(LOGGER_NAME) as log:
with self.assertRaises(ImproperlyConfigured):
call_command('edx_clear_expired_tokens')
log.check(
(
LOGGER_NAME,
'EXCEPTION',
'REFRESH_TOKEN_EXPIRE_SECONDS must be either a timedelta or seconds'
)
)
@patch('oauth2_provider.settings.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS', 3600)
def test_clear_expired_tokens(self):
initial_count = 5
now = timezone.now()
expires = now - timedelta(days=1)
users = UserFactory.create_batch(initial_count)
for user in users:
application = factories.ApplicationFactory(user=user)
factories.AccessTokenFactory(user=user, application=application, expires=expires)
self.assertEqual(
AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(),
initial_count
)
QuerySet.delete = counter(QuerySet.delete)
call_command('edx_clear_expired_tokens', batch_size=1, sleep_time=0)
self.assertEqual(QuerySet.delete.invocations, initial_count)
self.assertEqual(AccessToken.objects.filter(refresh_token__isnull=True, expires__lt=now).count(), 0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment