Commit 48d2191c by Peter Fogg

Merge pull request #30 from edx/peter-fogg/rate-limiting

Rate limiting.
parents bcb8d271 32a2b260
......@@ -4,7 +4,8 @@ from django.contrib import admin
from django.contrib.auth.admin import UserAdmin
from django.utils.translation import ugettext_lazy as _
from course_discovery.apps.core.models import User
from course_discovery.apps.core.forms import UserThrottleRateForm
from course_discovery.apps.core.models import User, UserThrottleRate
class CustomUserAdmin(UserAdmin):
......@@ -19,4 +20,11 @@ class CustomUserAdmin(UserAdmin):
)
class UserThrottleRateAdmin(admin.ModelAdmin):
""" Admin configuration for the UserThrottleRate model. """
form = UserThrottleRateForm
raw_id_fields = ('user',)
admin.site.register(User, CustomUserAdmin)
admin.site.register(UserThrottleRate, UserThrottleRateAdmin)
""" Core forms. """
from django import forms
from django.utils.translation import ugettext_lazy as _
from course_discovery.apps.core.models import UserThrottleRate
class UserThrottleRateForm(forms.ModelForm):
"""Form for the UserThrottleRate admin."""
class Meta:
model = UserThrottleRate
fields = ('user', 'rate')
def clean_rate(self):
rate = self.cleaned_data.get('rate')
if rate:
try:
num, period = rate.split('/')
int(num) # Only evaluated for the (possible) side effect of a ValueError
period_choices = ('second', 'minute', 'hour', 'day')
if period not in period_choices:
# Translators: 'period_choices' is a list of possible values, like ('second', 'minute', 'hour')
error_msg = _("period must be one of {period_choices}.").format(period_choices=period_choices)
raise forms.ValidationError(error_msg)
except ValueError:
error_msg = _("'rate' must be in the format defined by DRF, such as '100/hour'.")
raise forms.ValidationError(error_msg)
return rate
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
from django.conf import settings
class Migration(migrations.Migration):
dependencies = [
('core', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='UserThrottleRate',
fields=[
('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True, serialize=False)),
('rate', models.CharField(max_length=50)),
('user', models.ForeignKey(to=settings.AUTH_USER_MODEL)),
],
),
]
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0002_userthrottlerate'),
]
operations = [
migrations.AlterField(
model_name='userthrottlerate',
name='rate',
field=models.CharField(help_text='The rate of requests to limit this user to. The format is specified by Django Rest Framework (see http://www.django-rest-framework.org/api-guide/throttling/).', max_length=50),
),
]
......@@ -25,3 +25,15 @@ class User(AbstractUser):
def get_full_name(self):
return self.full_name or super(User, self).get_full_name()
class UserThrottleRate(models.Model):
"""Model for configuring a rate limit per-user."""
user = models.ForeignKey(User)
rate = models.CharField(
max_length=50,
help_text=_(
'The rate of requests to limit this user to. The format is specified by Django'
' Rest Framework (see http://www.django-rest-framework.org/api-guide/throttling/).')
)
"""Tests for core forms."""
import ddt
from django.test import TestCase
from course_discovery.apps.core.forms import UserThrottleRateForm
from course_discovery.apps.core.tests.factories import UserFactory
@ddt.ddt
class UserThrottleRateFormTest(TestCase):
"""Tests for the UserThrottleRate admin form."""
def setUp(self):
super(UserThrottleRateFormTest, self).setUp()
self.user = UserFactory()
def test_form_valid(self):
form = UserThrottleRateForm({'rate': '100/day', 'user': self.user.id})
self.assertTrue(form.is_valid())
@ddt.data(
('100', ["'rate' must be in the format defined by DRF, such as '100/hour'."]),
('100/fortnight', ["period must be one of ('second', 'minute', 'hour', 'day')."]),
('foo/day', ["'rate' must be in the format defined by DRF, such as '100/hour'."]),
(None, ['This field is required.']),
)
@ddt.unpack
def test_form_invalid_rate(self, rate, expected_error):
form = UserThrottleRateForm({'rate': rate, 'user': self.user.id})
self.assertFalse(form.is_valid())
self.assertEqual(form.errors, {
'rate': expected_error
})
from django.core.cache import cache
from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase
from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import UserFactory, USER_PASSWORD
from course_discovery.apps.core.throttles import OverridableUserRateThrottle
class RateLimitingTest(APITestCase):
"""
Testing rate limiting of API calls.
"""
def setUp(self):
super(RateLimitingTest, self).setUp()
self.url = reverse('api:v1:course-list')
self.user = UserFactory()
self.client.login(username=self.user.username, password=USER_PASSWORD)
def tearDown(self):
"""
Clear the cache, since DRF uses it for recording requests against a
URL. Django does not clear the cache between test runs.
"""
super(RateLimitingTest, self).tearDown()
cache.clear()
def _make_requests(self):
num_requests = OverridableUserRateThrottle().num_requests
for __ in range(num_requests + 1):
response = self.client.get(self.url)
return response
def test_rate_limiting(self):
response = self._make_requests()
self.assertEqual(response.status_code, 429)
def test_user_throttle_rate(self):
UserThrottleRate.objects.create(user=self.user, rate='1000/day')
response = self._make_requests()
self.assertEqual(response.status_code, 200)
def test_superuser_throttling(self):
self.user.is_superuser = True
self.user.is_staff = True
self.user.save()
response = self._make_requests()
self.assertEqual(response.status_code, 200)
"""Custom API throttles."""
from rest_framework.throttling import UserRateThrottle
from course_discovery.apps.core.models import UserThrottleRate
class OverridableUserRateThrottle(UserRateThrottle):
"""Rate throttling of requests, overridable on a per-user basis."""
def allow_request(self, request, view):
user = request.user
if user.is_superuser:
return True
try:
# Override this throttle's rate if applicable
user_throttle = UserThrottleRate.objects.get(user=user)
self.rate = user_throttle.rate
self.num_requests, self.duration = self.parse_rate(self.rate)
except UserThrottleRate.DoesNotExist:
pass
return super(OverridableUserRateThrottle, self).allow_request(request, view)
......@@ -249,7 +249,13 @@ REST_FRAMEWORK = {
'rest_framework.renderers.MultiPartRenderer',
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
)
),
'DEFAULT_THROTTLE_CLASSES': (
'course_discovery.apps.core.throttles.OverridableUserRateThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'user': '100/hour',
},
}
# NOTE (CCB): JWT_SECRET_KEY is intentionally not set here to avoid production releases with a public value.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment