Commit 00473d44 by Brian Beggs

Merge pull request #9929 from edx/bbeggs/merge-DRF-3.1

Merge DRF 3.1 in to master
parents 78f8fc39 53524efb
...@@ -44,7 +44,10 @@ from lms.envs.common import ( ...@@ -44,7 +44,10 @@ from lms.envs.common import (
PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES, PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES,
# The following setting is included as it is used to check whether to # The following setting is included as it is used to check whether to
# display credit eligibility table on the CMS or not. # display credit eligibility table on the CMS or not.
ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY,
# Django REST framework configuration
REST_FRAMEWORK,
) )
from path import Path as path from path import Path as path
from warnings import simplefilter from warnings import simplefilter
......
...@@ -6,7 +6,7 @@ from django.test.utils import override_settings ...@@ -6,7 +6,7 @@ from django.test.utils import override_settings
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.conf import settings from django.conf import settings
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import PermissionDenied
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
...@@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase): ...@@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase):
def test_perform_csrf_referer_check(self): def test_perform_csrf_referer_check(self):
request = self._fake_request() request = self._fake_request()
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'): with self.assertRaisesRegexp(PermissionDenied, 'CSRF'):
self.auth.enforce_csrf(request) self.auth.enforce_csrf(request)
@patch.dict(settings.FEATURES, { @patch.dict(settings.FEATURES, {
......
...@@ -11,7 +11,7 @@ from enrollment.errors import ( ...@@ -11,7 +11,7 @@ from enrollment.errors import (
CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError, CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError,
CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute
) )
from enrollment.serializers import CourseEnrollmentSerializer, CourseField from enrollment.serializers import CourseEnrollmentSerializer, CourseSerializer
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from student.models import ( from student.models import (
CourseEnrollment, NonExistentCourseError, EnrollmentClosedError, CourseEnrollment, NonExistentCourseError, EnrollmentClosedError,
...@@ -35,9 +35,30 @@ def get_course_enrollments(user_id): ...@@ -35,9 +35,30 @@ def get_course_enrollments(user_id):
""" """
qset = CourseEnrollment.objects.filter( qset = CourseEnrollment.objects.filter(
user__username=user_id, is_active=True user__username=user_id,
is_active=True
).order_by('created') ).order_by('created')
return CourseEnrollmentSerializer(qset).data
enrollments = CourseEnrollmentSerializer(qset, many=True).data
# Find deleted courses and filter them out of the results
deleted = []
valid = []
for enrollment in enrollments:
if enrollment.get("course_details") is not None:
valid.append(enrollment)
else:
deleted.append(enrollment)
if deleted:
log.warning(
(
u"Course enrollments for user %s reference "
u"courses that do not exist (this can occur if a course is deleted)."
), user_id,
)
return valid
def get_course_enrollment(username, course_id): def get_course_enrollment(username, course_id):
...@@ -271,4 +292,4 @@ def get_course_enrollment_info(course_id, include_expired=False): ...@@ -271,4 +292,4 @@ def get_course_enrollment_info(course_id, include_expired=False):
log.warning(msg) log.warning(msg)
raise CourseNotFoundError(msg) raise CourseNotFoundError(msg)
else: else:
return CourseField().to_native(course, include_expired=include_expired) return CourseSerializer(course, include_expired=include_expired).data
...@@ -30,32 +30,36 @@ class StringListField(serializers.CharField): ...@@ -30,32 +30,36 @@ class StringListField(serializers.CharField):
return [int(item) for item in items] return [int(item) for item in items]
class CourseField(serializers.RelatedField): class CourseSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""Read-Only representation of course enrollment information. """
Serialize a course descriptor and related information.
"""
Aggregates course information from the CourseDescriptor as well as the Course Modes configured course_id = serializers.CharField(source="id")
for enrolling in the course. enrollment_start = serializers.DateTimeField(format=None)
enrollment_end = serializers.DateTimeField(format=None)
course_start = serializers.DateTimeField(source="start", format=None)
course_end = serializers.DateTimeField(source="end", format=None)
invite_only = serializers.BooleanField(source="invitation_only")
course_modes = serializers.SerializerMethodField()
""" def __init__(self, *args, **kwargs):
self.include_expired = kwargs.pop("include_expired", False)
super(CourseSerializer, self).__init__(*args, **kwargs)
def to_native(self, course, **kwargs): def get_course_modes(self, obj):
course_modes = ModeSerializer( """
CourseMode.modes_for_course( Retrieve course modes associated with the course.
course.id, """
include_expired=kwargs.get('include_expired', False), course_modes = CourseMode.modes_for_course(
only_selectable=False obj.id,
) include_expired=self.include_expired,
).data only_selectable=False
)
return { return [
'course_id': unicode(course.id), ModeSerializer(mode).data
'enrollment_start': course.enrollment_start, for mode in course_modes
'enrollment_end': course.enrollment_end, ]
'course_start': course.start,
'course_end': course.end,
'invite_only': course.invitation_only,
'course_modes': course_modes,
}
class CourseEnrollmentSerializer(serializers.ModelSerializer): class CourseEnrollmentSerializer(serializers.ModelSerializer):
...@@ -65,34 +69,9 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): ...@@ -65,34 +69,9 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
the Course Descriptor and course modes, to give a complete representation of course enrollment. the Course Descriptor and course modes, to give a complete representation of course enrollment.
""" """
course_details = serializers.SerializerMethodField('get_course_details') course_details = CourseSerializer(source="course_overview")
user = serializers.SerializerMethodField('get_username') user = serializers.SerializerMethodField('get_username')
@property
def data(self):
serialized_data = super(CourseEnrollmentSerializer, self).data
# filter the results with empty courses 'course_details'
if isinstance(serialized_data, dict):
if serialized_data.get('course_details') is None:
return None
return serialized_data
return [enrollment for enrollment in serialized_data if enrollment.get('course_details')]
def get_course_details(self, model):
if model.course is None:
msg = u"Course '{0}' does not exist (maybe deleted), in which User (user_id: '{1}') is enrolled.".format(
model.course_id,
model.user.id
)
log.warning(msg)
return None
field = CourseField()
return field.to_native(model.course)
def get_username(self, model): def get_username(self, model):
"""Retrieves the username from the associated model.""" """Retrieves the username from the associated model."""
return model.username return model.username
......
...@@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase): ...@@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase):
@cross_domain_config @cross_domain_config
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
resp = self._cross_domain_post('invalid_csrf_token') resp = self._cross_domain_post('invalid_csrf_token')
self.assertEqual(resp.status_code, 401) self.assertEqual(resp.status_code, 403)
def _get_csrf_cookie(self): def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """ """Retrieve the cross-domain CSRF cookie. """
......
...@@ -5,10 +5,18 @@ This module requires that :class:`request_cache.middleware.RequestCache` ...@@ -5,10 +5,18 @@ This module requires that :class:`request_cache.middleware.RequestCache`
is installed in order to clear the cache after each request. is installed in order to clear the cache after each request.
""" """
import logging
from urlparse import urlparse
from django.conf import settings
from django.test.client import RequestFactory
from request_cache import middleware from request_cache import middleware
log = logging.getLogger(__name__)
def get_cache(name): def get_cache(name):
""" """
Return the request cache named ``name``. Return the request cache named ``name``.
...@@ -26,3 +34,38 @@ def get_request(): ...@@ -26,3 +34,38 @@ def get_request():
Return the current request. Return the current request.
""" """
return middleware.RequestCache.get_current_request() return middleware.RequestCache.get_current_request()
def get_request_or_stub():
"""
Return the current request or a stub request.
If called outside the context of a request, construct a fake
request that can be used to build an absolute URI.
This is useful in cases where we need to pass in a request object
but don't have an active request (for example, in test cases).
"""
request = get_request()
if request is None:
log.warning(
"Could not retrieve the current request. "
"A stub request will be created instead using settings.SITE_NAME. "
"This should be used *only* in test cases, never in production!"
)
# The settings SITE_NAME may contain a port number, so we need to
# parse the full URL.
full_url = "http://{site_name}".format(site_name=settings.SITE_NAME)
parsed_url = urlparse(full_url)
# Construct the fake request. This can be used to construct absolute
# URIs to other paths.
return RequestFactory(
SERVER_NAME=parsed_url.hostname,
SERVER_PORT=parsed_url.port or 80,
).get("/")
else:
return request
"""
Tests for the request cache.
"""
from django.conf import settings
from django.test import TestCase
from request_cache import get_request_or_stub
class TestRequestCache(TestCase):
"""
Tests for the request cache.
"""
def test_get_request_or_stub(self):
# Outside the context of the request, we should still get a request
# that allows us to build an absolute URI.
stub = get_request_or_stub()
expected_url = "http://{site_name}/foobar".format(site_name=settings.SITE_NAME)
self.assertEqual(stub.build_absolute_uri("foobar"), expected_url)
...@@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase): ...@@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase):
) )
create_team_page.submit_form() create_team_page.submit_form()
team_page = TeamPage(self.browser, self.course_id) team_page = TeamPage(self.browser, self.course_id)
self.assertTrue(team_page.is_browser_on_page) self.assertTrue(team_page.is_browser_on_page())
team_page.click_all_topics() team_page.click_all_topics()
self.assertTrue(self.topics_page.is_browser_on_page()) self.assertTrue(self.topics_page.is_browser_on_page())
self.topics_page.wait_for_ajax() self.topics_page.wait_for_ajax()
......
...@@ -18,7 +18,12 @@ class CourseModeSerializer(serializers.ModelSerializer): ...@@ -18,7 +18,12 @@ class CourseModeSerializer(serializers.ModelSerializer):
""" CourseMode serializer. """ """ CourseMode serializer. """
name = serializers.CharField(source='mode_slug') name = serializers.CharField(source='mode_slug')
price = serializers.IntegerField(source='min_price') price = serializers.IntegerField(source='min_price')
expires = serializers.DateTimeField(source='expiration_datetime', required=False, blank=True) expires = serializers.DateTimeField(
source='expiration_datetime',
required=False,
allow_null=True,
format=None
)
def get_identity(self, data): def get_identity(self, data):
try: try:
...@@ -56,8 +61,8 @@ class CourseSerializer(serializers.Serializer): ...@@ -56,8 +61,8 @@ class CourseSerializer(serializers.Serializer):
""" Course serializer. """ """ Course serializer. """
id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name
name = serializers.CharField(read_only=True) name = serializers.CharField(read_only=True)
verification_deadline = serializers.DateTimeField(blank=True) verification_deadline = serializers.DateTimeField(format=None, allow_null=True, required=False)
modes = CourseModeSerializer(many=True, allow_add_remove=True) modes = CourseModeSerializer(many=True)
def validate(self, attrs): def validate(self, attrs):
""" Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """ """ Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """
...@@ -68,7 +73,7 @@ class CourseSerializer(serializers.Serializer): ...@@ -68,7 +73,7 @@ class CourseSerializer(serializers.Serializer):
# Find the earliest upgrade deadline # Find the earliest upgrade deadline
for mode in attrs['modes']: for mode in attrs['modes']:
expires = mode.expiration_datetime expires = mode.get("expiration_datetime")
if expires: if expires:
# If we don't already have an upgrade_deadline value, use datetime.max so that we can actually # If we don't already have an upgrade_deadline value, use datetime.max so that we can actually
# complete the comparison. # complete the comparison.
...@@ -82,9 +87,28 @@ class CourseSerializer(serializers.Serializer): ...@@ -82,9 +87,28 @@ class CourseSerializer(serializers.Serializer):
return attrs return attrs
def restore_object(self, attrs, instance=None): def create(self, validated_data):
if instance is None: """Create course modes for a course. """
return Course(attrs['id'], attrs['modes'], attrs['verification_deadline']) course = Course(
validated_data["id"],
self._new_course_mode_models(validated_data["modes"]),
verification_deadline=validated_data["verification_deadline"]
)
course.save()
return course
def update(self, instance, validated_data):
"""Update course modes for an existing course. """
validated_data["modes"] = self._new_course_mode_models(validated_data["modes"])
instance.update(attrs) instance.update(validated_data)
instance.save()
return instance return instance
@staticmethod
def _new_course_mode_models(modes_data):
"""Convert validated course mode data to CourseMode objects. """
return [
CourseMode(**modes_dict)
for modes_dict in modes_data
]
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import logging import logging
from django.http import Http404 from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
...@@ -10,6 +11,7 @@ from commerce.api.v1.models import Course ...@@ -10,6 +11,7 @@ from commerce.api.v1.models import Course
from commerce.api.v1.permissions import ApiKeyOrModelPermission from commerce.api.v1.permissions import ApiKeyOrModelPermission
from commerce.api.v1.serializers import CourseSerializer from commerce.api.v1.serializers import CourseSerializer
from course_modes.models import CourseMode from course_modes.models import CourseMode
from openedx.core.lib.api.mixins import PutAsCreateMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -19,12 +21,13 @@ class CourseListView(ListAPIView): ...@@ -19,12 +21,13 @@ class CourseListView(ListAPIView):
authentication_classes = (OAuth2Authentication, SessionAuthentication,) authentication_classes = (OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = CourseSerializer serializer_class = CourseSerializer
pagination_class = None
def get_queryset(self): def get_queryset(self):
return Course.iterator() return list(Course.iterator())
class CourseRetrieveUpdateView(RetrieveUpdateAPIView): class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView):
""" Retrieve, update, or create courses/modes. """ """ Retrieve, update, or create courses/modes. """
lookup_field = 'id' lookup_field = 'id'
lookup_url_kwarg = 'course_id' lookup_url_kwarg = 'course_id'
...@@ -33,6 +36,11 @@ class CourseRetrieveUpdateView(RetrieveUpdateAPIView): ...@@ -33,6 +36,11 @@ class CourseRetrieveUpdateView(RetrieveUpdateAPIView):
permission_classes = (ApiKeyOrModelPermission,) permission_classes = (ApiKeyOrModelPermission,)
serializer_class = CourseSerializer serializer_class = CourseSerializer
# Django Rest Framework v3 requires that we provide a queryset.
# Note that we're overriding `get_object()` below to return a `Course`
# rather than a CourseMode, so this isn't really used.
queryset = CourseMode.objects.all()
def get_object(self, queryset=None): def get_object(self, queryset=None):
course_id = self.kwargs.get(self.lookup_url_kwarg) course_id = self.kwargs.get(self.lookup_url_kwarg)
course = Course.get(course_id) course = Course.get(course_id)
......
...@@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer): ...@@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer):
id = serializers.CharField() # pylint: disable=invalid-name id = serializers.CharField() # pylint: disable=invalid-name
name = serializers.CharField(source='display_name') name = serializers.CharField(source='display_name')
category = serializers.CharField() category = serializers.CharField()
org = serializers.SerializerMethodField('get_org') org = serializers.SerializerMethodField()
run = serializers.SerializerMethodField('get_run') run = serializers.SerializerMethodField()
course = serializers.SerializerMethodField('get_course') course = serializers.SerializerMethodField()
uri = serializers.SerializerMethodField('get_uri') uri = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField('get_image_url') image_url = serializers.SerializerMethodField()
start = serializers.DateTimeField() start = serializers.DateTimeField()
end = serializers.DateTimeField() end = serializers.DateTimeField()
......
...@@ -36,6 +36,23 @@ class CourseViewTestsMixin(object): ...@@ -36,6 +36,23 @@ class CourseViewTestsMixin(object):
""" """
view = None view = None
raw_grader = [
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
"min_count": 4,
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
]
def setUp(self): def setUp(self):
super(CourseViewTestsMixin, self).setUp() super(CourseViewTestsMixin, self).setUp()
self.create_user_and_access_token() self.create_user_and_access_token()
...@@ -51,22 +68,7 @@ class CourseViewTestsMixin(object): ...@@ -51,22 +68,7 @@ class CourseViewTestsMixin(object):
@classmethod @classmethod
def create_course_data(cls): def create_course_data(cls):
cls.invalid_course_id = 'foo/bar/baz' cls.invalid_course_id = 'foo/bar/baz'
cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=[ cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=cls.raw_grader)
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
"min_count": 4,
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
])
cls.course_id = unicode(cls.course.id) cls.course_id = unicode(cls.course.id)
with cls.store.bulk_operations(cls.course.id, emit_signals=False): with cls.store.bulk_operations(cls.course.id, emit_signals=False):
cls.sequential = ItemFactory.create( cls.sequential = ItemFactory.create(
...@@ -408,6 +410,55 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar ...@@ -408,6 +410,55 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar
self.assertListEqual(response.data, expected) self.assertListEqual(response.data, expected)
class CourseGradingPolicyMissingFieldsTests(CourseDetailTestMixin, CourseViewTestsMixin, SharedModuleStoreTestCase):
view = 'course_structure_api:v0:grading_policy'
# Update the raw grader to have missing keys
raw_grader = [
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
# Deleted "min_count" key
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
]
@classmethod
def setUpClass(cls):
super(CourseGradingPolicyMissingFieldsTests, cls).setUpClass()
cls.create_course_data()
def test_get(self):
"""
The view should return grading policy for a course.
"""
response = super(CourseGradingPolicyMissingFieldsTests, self).test_get()
expected = [
{
"count": 24,
"weight": 0.2,
"assignment_type": "Homework",
"dropped": 0
},
{
"count": None,
"weight": 0.8,
"assignment_type": "Exam",
"dropped": 0
}
]
self.assertListEqual(response.data, expected)
##################################################################################### #####################################################################################
# #
# The following Mixins/Classes collectively test the CourseBlocksAndNavigation view. # The following Mixins/Classes collectively test the CourseBlocksAndNavigation view.
......
...@@ -6,7 +6,8 @@ import logging ...@@ -6,7 +6,8 @@ import logging
from django.conf import settings from django.conf import settings
from django.http import Http404 from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.exceptions import AuthenticationFailed, ParseError from rest_framework.exceptions import AuthenticationFailed, ParseError
from rest_framework.generics import RetrieveAPIView, ListAPIView from rest_framework.generics import RetrieveAPIView, ListAPIView
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
...@@ -21,7 +22,6 @@ from courseware.access import has_access ...@@ -21,7 +22,6 @@ from courseware.access import has_access
from courseware.model_data import FieldDataCache from courseware.model_data import FieldDataCache
from courseware.module_render import get_module_for_descriptor from courseware.module_render import get_module_for_descriptor
from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes
from openedx.core.lib.api.serializers import PaginationSerializer
from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors
from student.roles import CourseInstructorRole, CourseStaffRole from student.roles import CourseInstructorRole, CourseStaffRole
from util.module_utils import get_dynamic_descriptor_children from util.module_utils import get_dynamic_descriptor_children
...@@ -157,9 +157,6 @@ class CourseList(CourseViewMixin, ListAPIView): ...@@ -157,9 +157,6 @@ class CourseList(CourseViewMixin, ListAPIView):
* end: The course end date. If course end date is not specified, the * end: The course end date. If course end date is not specified, the
value is null. value is null.
""" """
paginate_by = 10
paginate_by_param = 'page_size'
pagination_serializer_class = PaginationSerializer
serializer_class = serializers.CourseSerializer serializer_class = serializers.CourseSerializer
def get_queryset(self): def get_queryset(self):
......
...@@ -25,7 +25,6 @@ from xmodule.modulestore.django import modulestore ...@@ -25,7 +25,6 @@ from xmodule.modulestore.django import modulestore
from xmodule.modulestore.exceptions import ItemNotFoundError from xmodule.modulestore.exceptions import ItemNotFoundError
from .models import StudentModule from .models import StudentModule
from .module_render import get_module_for_descriptor from .module_render import get_module_for_descriptor
from submissions import api as sub_api # installed from the edx-submissions repository
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from openedx.core.djangoapps.signals.signals import GRADES_UPDATED from openedx.core.djangoapps.signals.signals import GRADES_UPDATED
...@@ -349,8 +348,13 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c ...@@ -349,8 +348,13 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c
# Dict of item_ids -> (earned, possible) point tuples. This *only* grabs # Dict of item_ids -> (earned, possible) point tuples. This *only* grabs
# scores that were registered with the submissions API, which for the moment # scores that were registered with the submissions API, which for the moment
# means only openassessment (edx-ora2) # means only openassessment (edx-ora2)
# We need to import this here to avoid a circular dependency of the form:
# XBlock --> submissions --> Django Rest Framework error strings -->
# Django translation --> ... --> courseware --> submissions
from submissions import api as sub_api # installed from the edx-submissions repository
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
max_scores_cache = MaxScoresCache.create_for_course(course) max_scores_cache = MaxScoresCache.create_for_course(course)
# For the moment, we have to get scorable_locations from field_data_cache # For the moment, we have to get scorable_locations from field_data_cache
# and not from scores_client, because scores_client is ignorant of things # and not from scores_client, because scores_client is ignorant of things
# in the submissions API. As a further refactoring step, submissions should # in the submissions API. As a further refactoring step, submissions should
...@@ -565,7 +569,12 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl ...@@ -565,7 +569,12 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl
course_module = getattr(course_module, '_x_module', course_module) course_module = getattr(course_module, '_x_module', course_module)
# We need to import this here to avoid a circular dependency of the form:
# XBlock --> submissions --> Django Rest Framework error strings -->
# Django translation --> ... --> courseware --> submissions
from submissions import api as sub_api # installed from the edx-submissions repository
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id)) submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
max_scores_cache = MaxScoresCache.create_for_course(course) max_scores_cache = MaxScoresCache.create_for_course(course)
# For the moment, we have to get scorable_locations from field_data_cache # For the moment, we have to get scorable_locations from field_data_cache
# and not from scores_client, because scores_client is ignorant of things # and not from scores_client, because scores_client is ignorant of things
......
...@@ -560,7 +560,7 @@ def create_thread(request, thread_data): ...@@ -560,7 +560,7 @@ def create_thread(request, thread_data):
if not (serializer.is_valid() and actions_form.is_valid()): if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save() serializer.save()
cc_thread = serializer.object cc_thread = serializer.instance
thread_created.send(sender=None, user=user, post=cc_thread) thread_created.send(sender=None, user=user, post=cc_thread)
api_thread = serializer.data api_thread = serializer.data
_do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context) _do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
...@@ -606,7 +606,7 @@ def create_comment(request, comment_data): ...@@ -606,7 +606,7 @@ def create_comment(request, comment_data):
if not (serializer.is_valid() and actions_form.is_valid()): if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items())) raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save() serializer.save()
cc_comment = serializer.object cc_comment = serializer.instance
comment_created.send(sender=None, user=request.user, post=cc_comment) comment_created.send(sender=None, user=request.user, post=cc_comment)
api_comment = serializer.data api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context) _do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context)
......
""" """
Discussion API pagination support Discussion API pagination support
""" """
from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField from rest_framework.utils.urls import replace_query_param
class _PaginationSerializer(BasePaginationSerializer):
"""
A pagination serializer without the count field, because the Comments
Service does not return result counts
"""
next = NextPageField(source="*")
previous = PreviousPageField(source="*")
class _Page(object): class _Page(object):
...@@ -52,7 +43,25 @@ def get_paginated_data(request, results, page_num, per_page): ...@@ -52,7 +43,25 @@ def get_paginated_data(request, results, page_num, per_page):
previous: The URL for the previous page previous: The URL for the previous page
results: The results on this page results: The results on this page
""" """
return _PaginationSerializer( # Note: Previous versions of this function used Django Rest Framework's
instance=_Page(results, page_num, per_page), # paginated serializer. With the upgrade to DRF 3.1, paginated serializers
context={"request": request} # have been removed. We *could* use DRF's paginator classes, but there are
).data # some slight differences between how DRF does pagination and how we're doing
# pagination here. (For example, we respond with a next_url param even if
# there is only one result on the current page.) To maintain backwards
# compatability, we simulate the behavior that DRF used to provide.
page = _Page(results, page_num, per_page)
next_url, previous_url = None, None
base_url = request.build_absolute_uri()
if page.has_next():
next_url = replace_query_param(base_url, "page", page.next_page_number())
if page.has_previous():
previous_url = replace_query_param(base_url, "page", page.previous_page_number())
return {
"next": next_url,
"previous": previous_url,
"results": results,
}
...@@ -1497,8 +1497,9 @@ class CreateThreadTest( ...@@ -1497,8 +1497,9 @@ class CreateThreadTest(
self.assertEqual(actual_post_data["group_id"], [str(cohort.id)]) self.assertEqual(actual_post_data["group_id"], [str(cohort.id)])
else: else:
self.assertNotIn("group_id", actual_post_data) self.assertNotIn("group_id", actual_post_data)
except ValidationError: except ValidationError as ex:
self.assertTrue(expected_error) if not expected_error:
self.fail("Unexpected validation error: {}".format(ex))
def test_following(self): def test_following(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id"})
...@@ -2239,7 +2240,7 @@ class UpdateThreadTest( ...@@ -2239,7 +2240,7 @@ class UpdateThreadTest(
update_thread(self.request, "test_thread", {"raw_body": ""}) update_thread(self.request, "test_thread", {"raw_body": ""})
self.assertEqual( self.assertEqual(
assertion.exception.message_dict, assertion.exception.message_dict,
{"raw_body": ["This field is required."]} {"raw_body": ["This field may not be blank."]}
) )
......
...@@ -523,9 +523,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -523,9 +523,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
data = self.minimal_data.copy() data = self.minimal_data.copy()
data.update({field: value for field in ["topic_id", "title", "raw_body"]}) data.update({field: value for field in ["topic_id", "title", "raw_body"]})
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request)) serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
self.assertFalse(serializer.is_valid())
self.assertEqual( self.assertEqual(
serializer.errors, serializer.errors,
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
) )
def test_create_type(self): def test_create_type(self):
...@@ -592,9 +593,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -592,9 +593,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
partial=True, partial=True,
context=get_context(self.course, self.request) context=get_context(self.course, self.request)
) )
self.assertFalse(serializer.is_valid())
self.assertEqual( self.assertEqual(
serializer.errors, serializer.errors,
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]} {field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
) )
def test_update_course_id(self): def test_update_course_id(self):
...@@ -604,6 +606,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -604,6 +606,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
partial=True, partial=True,
context=get_context(self.course, self.request) context=get_context(self.course, self.request)
) )
self.assertFalse(serializer.is_valid())
self.assertEqual( self.assertEqual(
serializer.errors, serializer.errors,
{"course_id": ["This field is not allowed in an update."]} {"course_id": ["This field is not allowed in an update."]}
...@@ -769,7 +772,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -769,7 +772,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["parent_id"] = None data["parent_id"] = None
serializer = CommentSerializer(data=data, context=context) serializer = CommentSerializer(data=data, context=context)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {"parent_id": ["Comment level is too deep."]}) self.assertEqual(serializer.errors, {"non_field_errors": ["Comment level is too deep."]})
def test_create_missing_field(self): def test_create_missing_field(self):
for field in self.minimal_data: for field in self.minimal_data:
...@@ -855,9 +858,10 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -855,9 +858,10 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
partial=True, partial=True,
context=get_context(self.course, self.request) context=get_context(self.course, self.request)
) )
self.assertFalse(serializer.is_valid())
self.assertEqual( self.assertEqual(
serializer.errors, serializer.errors,
{"raw_body": ["This field is required."]} {"raw_body": ["This field may not be blank."]}
) )
@ddt.data("thread_id", "parent_id") @ddt.data("thread_id", "parent_id")
...@@ -868,6 +872,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul ...@@ -868,6 +872,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
partial=True, partial=True,
context=get_context(self.course, self.request) context=get_context(self.course, self.request)
) )
self.assertFalse(serializer.is_valid())
self.assertEqual( self.assertEqual(
serializer.errors, serializer.errors,
{field: ["This field is not allowed in an update."]} {field: ["This field is not allowed in an update."]}
......
...@@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest ...@@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest
content_type="application/json" content_type="application/json"
) )
expected_response_data = { expected_response_data = {
"field_errors": {"title": {"developer_message": "This field is required."}} "field_errors": {"title": {"developer_message": "This field may not be blank."}}
} }
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content) response_data = json.loads(response.content)
...@@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
200, 200,
{ {
"results": expected_comments, "results": expected_comments,
"next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format( "next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format(
self.thread_id self.thread_id
), ),
"previous": None, "previous": None,
...@@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes ...@@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes
content_type="application/json" content_type="application/json"
) )
expected_response_data = { expected_response_data = {
"field_errors": {"raw_body": {"developer_message": "This field is required."}} "field_errors": {"raw_body": {"developer_message": "This field may not be blank."}}
} }
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content) response_data = json.loads(response.content)
......
...@@ -3,7 +3,8 @@ Discussion API views ...@@ -3,7 +3,8 @@ Discussion API views
""" """
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
......
...@@ -32,8 +32,6 @@ from xmodule.modulestore.tests.factories import check_mongo_calls ...@@ -32,8 +32,6 @@ from xmodule.modulestore.tests.factories import check_mongo_calls
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore import ModuleStoreEnum
from teams.tests.factories import CourseTeamFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -1290,6 +1288,7 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe ...@@ -1290,6 +1288,7 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe
topic_id='topic_id', topic_id='topic_id',
discussion_topic_id=self.team_commentable_id discussion_topic_id=self.team_commentable_id
) )
self.team.add_user(self.student_in_team) self.team.add_user(self.student_in_team)
# Dummy commentable ID not linked to a team # Dummy commentable ID not linked to a team
......
...@@ -717,6 +717,7 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase): ...@@ -717,6 +717,7 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase):
topic_id='topic_id', topic_id='topic_id',
discussion_topic_id=self.discussion_topic_id discussion_topic_id=self.discussion_topic_id
) )
self.team.add_user(self.user) # pylint: disable=no-member self.team.add_user(self.user) # pylint: disable=no-member
def test_context_can_be_standalone(self, mock_request): def test_context_can_be_standalone(self, mock_request):
...@@ -1093,7 +1094,9 @@ class InlineDiscussionTestCase(ModuleStoreTestCase): ...@@ -1093,7 +1094,9 @@ class InlineDiscussionTestCase(ModuleStoreTestCase):
course_id=self.course.id, course_id=self.course.id,
discussion_topic_id=self.discussion1.discussion_id discussion_topic_id=self.discussion1.discussion_id
) )
team.add_user(self.student) # pylint: disable=no-member team.add_user(self.student) # pylint: disable=no-member
response = self.send_request(mock_request) response = self.send_request(mock_request)
self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE) self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE)
self.verify_response(response) self.verify_response(response)
......
...@@ -149,6 +149,8 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin): ...@@ -149,6 +149,8 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin):
def test_team_discussion_id_not_cohorted(self, mock_request): def test_team_discussion_id_not_cohorted(self, mock_request):
team = CourseTeamFactory(course_id=self.course.id) team = CourseTeamFactory(course_id=self.course.id)
team.add_user(self.student) # pylint: disable=no-member team.add_user(self.student) # pylint: disable=no-member
self.call_view(mock_request, team.discussion_topic_id, self.student, None) self.call_view(mock_request, team.discussion_topic_id, self.student, None)
self._assert_comments_service_called_without_group_id(mock_request) self._assert_comments_service_called_without_group_id(mock_request)
...@@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView): ...@@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView):
serializer_class = serializers.CoursesWithFriendsSerializer serializer_class = serializers.CoursesWithFriendsSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.GET, files=request.FILES) serializer = self.get_serializer(data=request.GET)
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
...@@ -61,4 +61,5 @@ class CoursesWithFriends(generics.ListAPIView): ...@@ -61,4 +61,5 @@ class CoursesWithFriends(generics.ListAPIView):
and is_mobile_available_for_user(self.request.user, enrollment.course) and is_mobile_available_for_user(self.request.user, enrollment.course)
] ]
return Response(CourseEnrollmentSerializer(courses, context={'request': request}).data) serializer = CourseEnrollmentSerializer(courses, context={'request': request}, many=True)
return Response(serializer.data)
...@@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView): ...@@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView):
serializer_class = serializers.FriendsInCourseSerializer serializer_class = serializers.FriendsInCourseSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.GET, files=request.FILES) serializer = self.get_serializer(data=request.GET)
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
...@@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin): ...@@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin):
serializer_class = serializers.GroupSerializer serializer_class = serializers.GroupSerializer
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(data=request.DATA)
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
try: try:
...@@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin): ...@@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin):
serializer_class = serializers.GroupsMembersSerializer serializer_class = serializers.GroupsMembersSerializer
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(data=request.DATA)
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
graph = facebook_graph_api() graph = facebook_graph_api()
url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members" url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members"
member_ids = serializer.object['member_ids'].split(',') member_ids = serializer.data['member_ids'].split(',')
response = {} response = {}
for member_id in member_ids: for member_id in member_ids:
try: try:
......
...@@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer): ...@@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer):
""" """
Serializes user social settings Serializes user social settings
""" """
share_with_facebook_friends = serializers.BooleanField(required=True, default=False) share_with_facebook_friends = serializers.BooleanField(required=True)
...@@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView): ...@@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView):
serializer_class = serializers.UserSharingSerializar serializer_class = serializers.UserSharingSerializar
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
value = serializer.object['share_with_facebook_friends'] value = serializer.data['share_with_facebook_friends']
set_user_preference(request.user, "share_with_facebook_friends", value) set_user_preference(request.user, "share_with_facebook_friends", value)
return self.get(request, *args, **kwargs) return self.get(request, *args, **kwargs)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
...@@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer): ...@@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer):
the error message. the error message.
""" """
try: try:
graph = facebook.GraphAPI(serializer.object['oauth_token']) graph = facebook.GraphAPI(serializer.data['oauth_token'])
friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends") friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends")
return get_pagination(friends) return get_pagination(friends)
except facebook.GraphAPIError, ex: except facebook.GraphAPIError, ex:
......
...@@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE ...@@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE
class CourseOverviewField(serializers.RelatedField): class CourseOverviewField(serializers.RelatedField):
"""Custom field to wrap a CourseDescriptor object. Read-only.""" """Custom field to wrap a CourseDescriptor object. Read-only."""
def to_native(self, course_overview): def to_representation(self, course_overview):
course_id = unicode(course_overview.id) course_id = unicode(course_overview.id)
request = self.context.get('request', None) request = self.context.get('request', None)
if request: if request:
...@@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer): ...@@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
""" """
Serializes CourseEnrollment models Serializes CourseEnrollment models
""" """
course = CourseOverviewField(source="course_overview") course = CourseOverviewField(source="course_overview", read_only=True)
certificate = serializers.SerializerMethodField('get_certificate') certificate = serializers.SerializerMethodField()
def get_certificate(self, model): def get_certificate(self, model):
"""Returns the information about the user's certificate in the course.""" """Returns the information about the user's certificate in the course."""
...@@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer): ...@@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer):
""" """
Serializes User models Serializes User models
""" """
name = serializers.Field(source='profile.name') name = serializers.ReadOnlyField(source='profile.name')
course_enrollments = serializers.HyperlinkedIdentityField( course_enrollments = serializers.HyperlinkedIdentityField(
view_name='courseenrollment-detail', view_name='courseenrollment-detail',
lookup_field='username' lookup_field='username'
......
...@@ -251,6 +251,14 @@ class UserCourseEnrollmentsList(generics.ListAPIView): ...@@ -251,6 +251,14 @@ class UserCourseEnrollmentsList(generics.ListAPIView):
serializer_class = CourseEnrollmentSerializer serializer_class = CourseEnrollmentSerializer
lookup_field = 'username' lookup_field = 'username'
# In Django Rest Framework v3, there is a default pagination
# class that transmutes the response data into a dictionary
# with pagination information. The original response data (a list)
# is stored in a "results" value of the dictionary.
# For backwards compatibility with the existing API, we disable
# the default behavior by setting the pagination_class to None.
pagination_class = None
def get_queryset(self): def get_queryset(self):
enrollments = self.queryset.filter( enrollments = self.queryset.filter(
user__username=self.kwargs['username'], user__username=self.kwargs['username'],
......
...@@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer): ...@@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer):
* course_groups * course_groups
* roles__permissions * roles__permissions
""" """
name = serializers.SerializerMethodField("get_name") name = serializers.SerializerMethodField()
preferences = serializers.SerializerMethodField("get_preferences") preferences = serializers.SerializerMethodField()
course_info = serializers.SerializerMethodField("get_course_info") course_info = serializers.SerializerMethodField()
def get_name(self, user): def get_name(self, user):
return user.profile.name return user.profile.name
......
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework.viewsets import ReadOnlyModelViewSet from rest_framework.viewsets import ReadOnlyModelViewSet
from rest_framework.response import Response
from rest_framework import pagination
from notification_prefs import NOTIFICATION_PREF_KEY from notification_prefs import NOTIFICATION_PREF_KEY
from notifier_api.serializers import NotifierUserSerializer from notifier_api.serializers import NotifierUserSerializer
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
class NotifierPaginator(pagination.PageNumberPagination):
"""
Paginator for the notifier API.
"""
page_size = 10
page_size_query_param = "page_size"
def get_paginated_response(self, data):
"""
Construct a response with pagination information.
"""
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'count': self.page.paginator.count,
'results': data
})
class NotifierUsersViewSet(ReadOnlyModelViewSet): class NotifierUsersViewSet(ReadOnlyModelViewSet):
""" """
An endpoint that the notifier can use to retrieve users who have enabled An endpoint that the notifier can use to retrieve users who have enabled
...@@ -14,8 +35,7 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet): ...@@ -14,8 +35,7 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet):
""" """
permission_classes = (ApiKeyHeaderPermission,) permission_classes = (ApiKeyHeaderPermission,)
serializer_class = NotifierUserSerializer serializer_class = NotifierUserSerializer
paginate_by = 10 pagination_class = NotifierPaginator
paginate_by_param = "page_size"
# See NotifierUserSerializer for notes about related tables # See NotifierUserSerializer for notes about related tables
queryset = User.objects.filter( queryset = User.objects.filter(
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
import pytz import pytz
from datetime import datetime
from model_utils import FieldTracker from model_utils import FieldTracker
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
......
...@@ -10,6 +10,7 @@ from django.utils import translation ...@@ -10,6 +10,7 @@ from django.utils import translation
from functools import wraps from functools import wraps
from search.search_engine_base import SearchEngine from search.search_engine_base import SearchEngine
from request_cache import get_request_or_stub
from .errors import ElasticSearchConnectionError from .errors import ElasticSearchConnectionError
from .serializers import CourseTeamSerializer, CourseTeam from .serializers import CourseTeamSerializer, CourseTeam
...@@ -47,7 +48,15 @@ class CourseTeamIndexer(object): ...@@ -47,7 +48,15 @@ class CourseTeamIndexer(object):
Returns serialized object with additional search fields. Returns serialized object with additional search fields.
""" """
serialized_course_team = CourseTeamSerializer(self.course_team).data # Django Rest Framework v3.1 requires that we pass the request to the serializer
# so it can construct hyperlinks. To avoid changing the interface of this object,
# we retrieve the request from the request cache.
context = {
"request": get_request_or_stub()
}
serialized_course_team = CourseTeamSerializer(self.course_team, context=context).data
# Save the primary key so we can load the full objects easily after we search # Save the primary key so we can load the full objects easily after we search
serialized_course_team['pk'] = self.course_team.pk serialized_course_team['pk'] = self.course_team.pk
# Don't save the membership relations in elasticsearch # Don't save the membership relations in elasticsearch
......
...@@ -4,15 +4,44 @@ from django.contrib.auth.models import User ...@@ -4,15 +4,44 @@ from django.contrib.auth.models import User
from django.db.models import Count from django.db.models import Count
from django.conf import settings from django.conf import settings
from django_countries import countries
from rest_framework import serializers from rest_framework import serializers
from openedx.core.lib.api.serializers import CollapsedReferenceSerializer, PaginationSerializer from openedx.core.lib.api.serializers import CollapsedReferenceSerializer
from openedx.core.lib.api.fields import ExpandableField from openedx.core.lib.api.fields import ExpandableField
from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer
from .models import CourseTeam, CourseTeamMembership from .models import CourseTeam, CourseTeamMembership
class CountryField(serializers.Field):
"""
Field to serialize a country code.
"""
COUNTRY_CODES = dict(countries).keys()
def to_representation(self, obj):
"""
Represent the country as a 2-character unicode identifier.
"""
return unicode(obj)
def to_internal_value(self, data):
"""
Check that the code is a valid country code.
We leave the data in its original format so that the Django model's
CountryField can convert it to the internal representation used
by the django-countries library.
"""
if data and data not in self.COUNTRY_CODES:
raise serializers.ValidationError(
u"{code} is not a valid country code".format(code=data)
)
return data
class UserMembershipSerializer(serializers.ModelSerializer): class UserMembershipSerializer(serializers.ModelSerializer):
"""Serializes CourseTeamMemberships with only user and date_joined """Serializes CourseTeamMemberships with only user and date_joined
...@@ -43,6 +72,7 @@ class CourseTeamSerializer(serializers.ModelSerializer): ...@@ -43,6 +72,7 @@ class CourseTeamSerializer(serializers.ModelSerializer):
"""Serializes a CourseTeam with membership information.""" """Serializes a CourseTeam with membership information."""
id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name
membership = UserMembershipSerializer(many=True, read_only=True) membership = UserMembershipSerializer(many=True, read_only=True)
country = CountryField()
class Meta(object): class Meta(object):
"""Defines meta information for the ModelSerializer.""" """Defines meta information for the ModelSerializer."""
...@@ -66,6 +96,8 @@ class CourseTeamSerializer(serializers.ModelSerializer): ...@@ -66,6 +96,8 @@ class CourseTeamSerializer(serializers.ModelSerializer):
class CourseTeamCreationSerializer(serializers.ModelSerializer): class CourseTeamCreationSerializer(serializers.ModelSerializer):
"""Deserializes a CourseTeam for creation.""" """Deserializes a CourseTeam for creation."""
country = CountryField(required=False)
class Meta(object): class Meta(object):
"""Defines meta information for the ModelSerializer.""" """Defines meta information for the ModelSerializer."""
model = CourseTeam model = CourseTeam
...@@ -78,16 +110,17 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer): ...@@ -78,16 +110,17 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer):
"language", "language",
) )
def restore_object(self, attrs, instance=None): def create(self, validated_data):
"""Restores a CourseTeam instance from the given attrs.""" team = CourseTeam.create(
return CourseTeam.create( name=validated_data.get("name", ''),
name=attrs.get("name", ''), course_id=validated_data.get("course_id"),
course_id=attrs.get("course_id"), description=validated_data.get("description", ''),
description=attrs.get("description", ''), topic_id=validated_data.get("topic_id", ''),
topic_id=attrs.get("topic_id", ''), country=validated_data.get("country", ''),
country=attrs.get("country", ''), language=validated_data.get("language", ''),
language=attrs.get("language", ''),
) )
team.save()
return team
class CourseTeamSerializerWithoutMembership(CourseTeamSerializer): class CourseTeamSerializerWithoutMembership(CourseTeamSerializer):
...@@ -134,13 +167,6 @@ class MembershipSerializer(serializers.ModelSerializer): ...@@ -134,13 +167,6 @@ class MembershipSerializer(serializers.ModelSerializer):
read_only_fields = ("date_joined", "last_activity_at") read_only_fields = ("date_joined", "last_activity_at")
class PaginatedMembershipSerializer(PaginationSerializer):
"""Serializes team memberships with support for pagination."""
class Meta(object):
"""Defines meta information for the PaginatedMembershipSerializer."""
object_serializer_class = MembershipSerializer
class BaseTopicSerializer(serializers.Serializer): class BaseTopicSerializer(serializers.Serializer):
"""Serializes a topic without team_count.""" """Serializes a topic without team_count."""
description = serializers.CharField() description = serializers.CharField()
...@@ -155,7 +181,7 @@ class TopicSerializer(BaseTopicSerializer): ...@@ -155,7 +181,7 @@ class TopicSerializer(BaseTopicSerializer):
model to get the count. Requires that `context` is provided with a valid course_id model to get the count. Requires that `context` is provided with a valid course_id
in order to filter teams within the course. in order to filter teams within the course.
""" """
team_count = serializers.SerializerMethodField('get_team_count') team_count = serializers.SerializerMethodField()
def get_team_count(self, topic): def get_team_count(self, topic):
"""Get the number of teams associated with this topic""" """Get the number of teams associated with this topic"""
...@@ -166,31 +192,25 @@ class TopicSerializer(BaseTopicSerializer): ...@@ -166,31 +192,25 @@ class TopicSerializer(BaseTopicSerializer):
return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count() return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count()
class PaginatedTopicSerializer(PaginationSerializer): class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method
""" """
Serializes a set of topics, adding the team_count field to each topic individually, if team_count List serializer for efficiently serializing a set of topics.
is not already present in the topic data. Requires that `context` is provided with a valid course_id in
order to filter teams within the course.
""" """
class Meta(object):
"""Defines meta information for the PaginatedTopicSerializer."""
object_serializer_class = TopicSerializer
def to_representation(self, obj):
"""Adds team_count to each topic. """
data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj)
add_team_count(data, self.context["course_id"])
return data
class BulkTeamCountPaginatedTopicSerializer(PaginationSerializer):
class BulkTeamCountTopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method
""" """
Serializes a set of topics, adding the team_count field to each topic as a bulk operation per page Serializes a set of topics, adding the team_count field to each topic as a bulk operation.
(only on the page being returned). Requires that `context` is provided with a valid course_id in Requires that `context` is provided with a valid course_id in order to filter teams within the course.
order to filter teams within the course.
""" """
class Meta(object): class Meta: # pylint: disable=missing-docstring,old-style-class
"""Defines meta information for the BulkTeamCountPaginatedTopicSerializer.""" list_serializer_class = BulkTeamCountTopicListSerializer
object_serializer_class = BaseTopicSerializer
def __init__(self, *args, **kwargs):
"""Adds team_count to each topic on the current page."""
super(BulkTeamCountPaginatedTopicSerializer, self).__init__(*args, **kwargs)
add_team_count(self.data['results'], self.context['course_id'])
def add_team_count(topics, course_id): def add_team_count(topics, course_id):
......
...@@ -34,3 +34,10 @@ class CourseTeamMembershipFactory(DjangoModelFactory): ...@@ -34,3 +34,10 @@ class CourseTeamMembershipFactory(DjangoModelFactory):
class Meta(object): # pylint: disable=missing-docstring class Meta(object): # pylint: disable=missing-docstring
model = CourseTeamMembership model = CourseTeamMembership
last_activity_at = LAST_ACTIVITY_AT last_activity_at = LAST_ACTIVITY_AT
@classmethod
def _create(cls, model_class, *args, **kwargs):
"""Create the team membership. """
obj = model_class(*args, **kwargs)
obj.save()
return obj
...@@ -11,12 +11,9 @@ from xmodule.modulestore.tests.factories import CourseFactory ...@@ -11,12 +11,9 @@ from xmodule.modulestore.tests.factories import CourseFactory
from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory
from lms.djangoapps.teams.serializers import ( from lms.djangoapps.teams.serializers import (
BaseTopicSerializer, BulkTeamCountTopicSerializer,
PaginatedTopicSerializer,
BulkTeamCountPaginatedTopicSerializer,
TopicSerializer, TopicSerializer,
MembershipSerializer, MembershipSerializer,
add_team_count
) )
...@@ -73,21 +70,6 @@ class MembershipSerializerTestCase(SerializerTestCase): ...@@ -73,21 +70,6 @@ class MembershipSerializerTestCase(SerializerTestCase):
self.assertNotIn('membership', data['team']) self.assertNotIn('membership', data['team'])
class BaseTopicSerializerTestCase(SerializerTestCase):
"""
Tests for the `BaseTopicSerializer`, which should not serialize team count
data.
"""
def test_team_count_not_included(self):
"""Verifies that the `BaseTopicSerializer` does not include team count"""
with self.assertNumQueries(0):
serializer = BaseTopicSerializer(self.course.teams_topics[0])
self.assertEqual(
serializer.data,
{u'name': u'Tøpic', u'description': u'The bést topic!', u'id': u'0'}
)
class TopicSerializerTestCase(SerializerTestCase): class TopicSerializerTestCase(SerializerTestCase):
""" """
Tests for the `TopicSerializer`, which should serialize team count data for Tests for the `TopicSerializer`, which should serialize team count data for
...@@ -137,7 +119,7 @@ class TopicSerializerTestCase(SerializerTestCase): ...@@ -137,7 +119,7 @@ class TopicSerializerTestCase(SerializerTestCase):
) )
class BasePaginatedTopicSerializerTestCase(SerializerTestCase): class BaseTopicSerializerTestCase(SerializerTestCase):
""" """
Base class for testing the two paginated topic serializers. Base class for testing the two paginated topic serializers.
""" """
...@@ -191,13 +173,15 @@ class BasePaginatedTopicSerializerTestCase(SerializerTestCase): ...@@ -191,13 +173,15 @@ class BasePaginatedTopicSerializerTestCase(SerializerTestCase):
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0) self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
""" """
Tests for the `BulkTeamCountPaginatedTopicSerializer`, which should serialize team_count Tests for the `BulkTeamCountTopicSerializer`, which should serialize team_count
data for many topics with constant time SQL queries. data for many topics with constant time SQL queries.
""" """
__test__ = True __test__ = True
serializer = BulkTeamCountPaginatedTopicSerializer serializer = BulkTeamCountTopicSerializer
NUM_TOPICS = 6
def test_topics_with_no_team_counts(self): def test_topics_with_no_team_counts(self):
""" """
...@@ -222,13 +206,13 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer ...@@ -222,13 +206,13 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
one SQL query. one SQL query.
""" """
teams_per_topic = 10 teams_per_topic = 10
topics = self.setup_topics(num_topics=self.PAGE_SIZE + 1, teams_per_topic=teams_per_topic) topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics[:self.PAGE_SIZE], num_teams_per_topic=teams_per_topic, num_queries=1) self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
def test_scoped_within_course(self): def test_scoped_within_course(self):
"""Verify that team counts are scoped within a course.""" """Verify that team counts are scoped within a course."""
teams_per_topic = 10 teams_per_topic = 10
first_course_topics = self.setup_topics(num_topics=self.PAGE_SIZE, teams_per_topic=teams_per_topic) first_course_topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
duplicate_topic = first_course_topics[0] duplicate_topic = first_course_topics[0]
second_course = CourseFactory.create( second_course = CourseFactory.create(
teams_configuration={ teams_configuration={
...@@ -239,27 +223,44 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer ...@@ -239,27 +223,44 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id']) CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1) self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1)
def _merge_dicts(self, first, second):
"""Convenience method to merge two dicts in a single expression"""
result = first.copy()
result.update(second)
return result
class PaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase): def setup_topics(self, num_topics=5, teams_per_topic=0):
""" """
Tests for the `PaginatedTopicSerializer`, which will add team_count information per topic if not present. Helper method to set up topics on the course. Returns a list of
""" created topics.
__test__ = True """
serializer = PaginatedTopicSerializer self.course.teams_configuration['topics'] = []
topics = [
{u'name': u'Tøpic {}'.format(i), u'description': u'The bést topic! {}'.format(i), u'id': unicode(i)}
for i in xrange(num_topics)
]
for i in xrange(num_topics):
topic_id = unicode(i)
self.course.teams_configuration['topics'].append(topics[i])
for _ in xrange(teams_per_topic):
CourseTeamFactory.create(course_id=self.course.id, topic_id=topic_id)
return topics
def test_topics_with_team_counts(self): def assert_serializer_output(self, topics, num_teams_per_topic, num_queries):
""" """
Verify that we serialize topics with team_count, making one SQL query per topic. Verify that the serializer produced the expected topics.
""" """
teams_per_topic = 2 with self.assertNumQueries(num_queries):
topics = self.setup_topics(teams_per_topic=teams_per_topic) serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=5) self.assertEqual(
serializer.data,
[self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics]
)
def test_topics_with_team_counts_prepopulated(self): def test_no_topics(self):
""" """
Verify that if team_count is pre-populated, there are no additional SQL queries. Verify that we return no results and make no SQL queries for a page
with no topics.
""" """
teams_per_topic = 8 self.course.teams_configuration['topics'] = []
topics = self.setup_topics(teams_per_topic=teams_per_topic) self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
add_team_count(topics, self.course.id)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=0)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Tests for the teams API at the HTTP request level.""" """Tests for the teams API at the HTTP request level."""
import json import json
import pytz
from datetime import datetime from datetime import datetime
import pytz
from dateutil import parser from dateutil import parser
import ddt import ddt
from elasticsearch.exceptions import ConnectionError from elasticsearch.exceptions import ConnectionError
from mock import patch from mock import patch
from search.search_engine_base import SearchEngine from search.search_engine_base import SearchEngine
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.conf import settings from django.conf import settings
from django.db.models.signals import post_save from django.db.models.signals import post_save
from django.utils import translation from django.utils import translation
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from rest_framework.test import APITestCase, APIClient from rest_framework.test import APITestCase, APIClient
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from courseware.tests.factories import StaffFactory from courseware.tests.factories import StaffFactory
from common.test.utils import skip_signal from common.test.utils import skip_signal
from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory
from student.models import CourseEnrollment from student.models import CourseEnrollment
from util.testing import EventTestMixin from util.testing import EventTestMixin
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from .factories import CourseTeamFactory, LAST_ACTIVITY_AT from .factories import CourseTeamFactory, LAST_ACTIVITY_AT
from ..models import CourseTeamMembership from ..models import CourseTeamMembership
from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback
from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA
from django_comment_common.utils import seed_permissions_roles from django_comment_common.utils import seed_permissions_roles
...@@ -36,11 +35,23 @@ class TestDashboard(SharedModuleStoreTestCase): ...@@ -36,11 +35,23 @@ class TestDashboard(SharedModuleStoreTestCase):
"""Tests for the Teams dashboard.""" """Tests for the Teams dashboard."""
test_password = "test" test_password = "test"
NUM_TOPICS = 10
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(TestDashboard, cls).setUpClass() super(TestDashboard, cls).setUpClass()
cls.course = CourseFactory.create( cls.course = CourseFactory.create(
teams_configuration={"max_team_size": 10, "topics": [{"name": "foo", "id": 0, "description": "test topic"}]} teams_configuration={
"max_team_size": 10,
"topics": [
{
"name": "Topic {}".format(topic_id),
"id": topic_id,
"description": "Description for topic {}".format(topic_id)
}
for topic_id in range(cls.NUM_TOPICS)
]
}
) )
def setUp(self): def setUp(self):
...@@ -97,6 +108,30 @@ class TestDashboard(SharedModuleStoreTestCase): ...@@ -97,6 +108,30 @@ class TestDashboard(SharedModuleStoreTestCase):
response = self.client.get(teams_url) response = self.client.get(teams_url)
self.assertEqual(404, response.status_code) self.assertEqual(404, response.status_code)
def test_query_counts(self):
# Enroll in the course and log in
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.client.login(username=self.user.username, password=self.test_password)
# Check the query count on the dashboard With no teams
with self.assertNumQueries(15):
self.client.get(self.teams_url)
# Create some teams
for topic_id in range(self.NUM_TOPICS):
team = CourseTeamFactory.create(
name=u"Team for topic {}".format(topic_id),
course_id=self.course.id,
topic_id=topic_id,
)
# Add the user to the last team
team.add_user(self.user)
# Check the query count on the dashboard again
with self.assertNumQueries(19):
self.client.get(self.teams_url)
def test_bad_course_id(self): def test_bad_course_id(self):
""" """
Verifies expected behavior when course_id does not reference an existing course or is invalid. Verifies expected behavior when course_id does not reference an existing course or is invalid.
...@@ -252,6 +287,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): ...@@ -252,6 +287,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
self.users[user], course.id, check_access=True self.users[user], course.id, check_access=True
) )
# Django Rest Framework v3 requires us to pass a request to serializers
# that have URL fields. Since we're invoking this code outside the context
# of a request, we need to simulate that there's a request.
self.solar_team.add_user(self.users['student_enrolled']) self.solar_team.add_user(self.users['student_enrolled'])
self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team']) self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team'])
self.another_team.add_user(self.users['student_enrolled_both_courses_other_team']) self.another_team.add_user(self.users['student_enrolled_both_courses_other_team'])
...@@ -311,7 +349,17 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase): ...@@ -311,7 +349,17 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
response = func(url, data=data, content_type=content_type) response = func(url, data=data, content_type=content_type)
else: else:
response = func(url, data=data) response = func(url, data=data)
self.assertEqual(expected_status, response.status_code)
self.assertEqual(
expected_status,
response.status_code,
msg="Expected status {expected} but got {actual}: {content}".format(
expected=expected_status,
actual=response.status_code,
content=response.content,
)
)
if expected_status == 200: if expected_status == 200:
return json.loads(response.content) return json.loads(response.content)
else: else:
......
...@@ -1866,6 +1866,12 @@ INSTALLED_APPS = ( ...@@ -1866,6 +1866,12 @@ INSTALLED_APPS = (
'provider.oauth2', 'provider.oauth2',
'oauth2_provider', 'oauth2_provider',
# We don't use this directly (since we use OAuth2), but we need to install it anyway.
# When a user is deleted, Django queries all tables with a FK to the auth_user table,
# and since django-rest-framework-oauth imports this, it will try to access tables
# defined by oauth_provider. If those tables don't exist, an error can occur.
'oauth_provider',
'auth_exchange', 'auth_exchange',
# For the wiki # For the wiki
...@@ -1981,6 +1987,14 @@ INSTALLED_APPS = ( ...@@ -1981,6 +1987,14 @@ INSTALLED_APPS = (
CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52 CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52
######################### Django Rest Framework ########################
REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'openedx.core.lib.api.paginators.DefaultPagination',
'PAGE_SIZE': 10,
}
######################### MARKETING SITE ############################### ######################### MARKETING SITE ###############################
EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin' EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin'
EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info' EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info'
......
...@@ -9,16 +9,12 @@ from django.conf import settings ...@@ -9,16 +9,12 @@ from django.conf import settings
# Force settings to run so that the python path is modified # Force settings to run so that the python path is modified
settings.INSTALLED_APPS # pylint: disable=pointless-statement settings.INSTALLED_APPS # pylint: disable=pointless-statement
from instructor.services import InstructorService
from openedx.core.lib.django_startup import autostartup from openedx.core.lib.django_startup import autostartup
import edxmako import edxmako
import logging import logging
from monkey_patch import django_utils_translation from monkey_patch import django_utils_translation
import analytics import analytics
from edx_proctoring.runtime import set_runtime_service
from openedx.core.djangoapps.credit.services import CreditService
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -51,6 +47,11 @@ def run(): ...@@ -51,6 +47,11 @@ def run():
# right now edx_proctoring is dependent on the openedx.core.djangoapps.credit # right now edx_proctoring is dependent on the openedx.core.djangoapps.credit
# as well as the instructor dashboard (for deleting student attempts) # as well as the instructor dashboard (for deleting student attempts)
if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'): if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'):
# Import these here to avoid circular dependencies of the form:
# edx-platform app --> DRF --> django translation --> edx-platform app
from edx_proctoring.runtime import set_runtime_service
from instructor.services import InstructorService
from openedx.core.djangoapps.credit.services import CreditService
set_runtime_service('credit', CreditService()) set_runtime_service('credit', CreditService())
set_runtime_service('instructor', InstructorService()) set_runtime_service('instructor', InstructorService())
......
...@@ -124,4 +124,4 @@ def course_grading_policy(course_key): ...@@ -124,4 +124,4 @@ def course_grading_policy(course_key):
final grade. final grade.
""" """
course = _retrieve_course(course_key) course = _retrieve_course(course_key)
return GradingPolicySerializer(course.raw_grader).data return GradingPolicySerializer(course.raw_grader, many=True).data
""" """
API Serializers API Serializers
""" """
from collections import defaultdict
from rest_framework import serializers from rest_framework import serializers
...@@ -11,23 +13,58 @@ class GradingPolicySerializer(serializers.Serializer): ...@@ -11,23 +13,58 @@ class GradingPolicySerializer(serializers.Serializer):
dropped = serializers.IntegerField(source='drop_count') dropped = serializers.IntegerField(source='drop_count')
weight = serializers.FloatField() weight = serializers.FloatField()
def to_representation(self, obj):
"""
Return a representation of the grading policy.
"""
# Backwards compatibility with the behavior of DRF v2.
# When the grader dictionary was missing keys, DRF v2 would default to None;
# DRF v3 unhelpfully raises an exception.
return dict(
super(GradingPolicySerializer, self).to_representation(
defaultdict(lambda: None, obj)
)
)
# pylint: disable=invalid-name # pylint: disable=invalid-name
class BlockSerializer(serializers.Serializer): class BlockSerializer(serializers.Serializer):
""" Serializer for course structure block. """ """ Serializer for course structure block. """
id = serializers.CharField(source='usage_key') id = serializers.CharField(source='usage_key')
type = serializers.CharField(source='block_type') type = serializers.CharField(source='block_type')
parent = serializers.CharField(source='parent') parent = serializers.CharField(required=False)
display_name = serializers.CharField() display_name = serializers.CharField()
graded = serializers.BooleanField(default=False) graded = serializers.BooleanField(default=False)
format = serializers.CharField() format = serializers.CharField()
children = serializers.CharField() children = serializers.CharField()
def to_representation(self, obj):
"""
Return a representation of the block.
NOTE: this method maintains backwards compatibility with the behavior
of Django Rest Framework v2.
"""
data = super(BlockSerializer, self).to_representation(obj)
# Backwards compatibility with the behavior of DRF v2
# Include a NULL value for "parent" in the representation
# (instead of excluding the key entirely)
if obj.get("parent") is None:
data["parent"] = None
# Backwards compatibility with the behavior of DRF v2
# Leave the children list as a list instead of serializing
# it to a string.
data["children"] = obj["children"]
return data
class CourseStructureSerializer(serializers.Serializer): class CourseStructureSerializer(serializers.Serializer):
""" Serializer for course structure. """ """ Serializer for course structure. """
root = serializers.CharField(source='root') root = serializers.CharField()
blocks = serializers.SerializerMethodField('get_blocks') blocks = serializers.SerializerMethodField()
def get_blocks(self, structure): def get_blocks(self, structure):
""" Serialize the individual blocks. """ """ Serialize the individual blocks. """
......
...@@ -2,12 +2,33 @@ ...@@ -2,12 +2,33 @@
from rest_framework import serializers from rest_framework import serializers
from opaque_keys.edx.keys import CourseKey
from opaque_keys import InvalidKeyError
from openedx.core.djangoapps.credit.models import CreditCourse from openedx.core.djangoapps.credit.models import CreditCourse
class CourseKeyField(serializers.Field):
"""
Serializer field for a model CourseKey field.
"""
def to_representation(self, data):
"""Convert a course key to unicode. """
return unicode(data)
def to_internal_value(self, data):
"""Convert unicode to a course key. """
try:
return CourseKey.from_string(data)
except InvalidKeyError as ex:
raise serializers.ValidationError("Invalid course key: {msg}".format(msg=ex.msg))
class CreditCourseSerializer(serializers.ModelSerializer): class CreditCourseSerializer(serializers.ModelSerializer):
""" CreditCourse Serializer """ """ CreditCourse Serializer """
course_key = CourseKeyField()
class Meta(object): # pylint: disable=missing-docstring class Meta(object): # pylint: disable=missing-docstring
model = CreditCourse model = CreditCourse
exclude = ('id',) exclude = ('id',)
...@@ -393,10 +393,7 @@ class CreditCourseViewSetTests(TestCase): ...@@ -393,10 +393,7 @@ class CreditCourseViewSetTests(TestCase):
# POSTs without a CSRF token should fail. # POSTs without a CSRF token should fail.
response = client.post(self.path, data=json.dumps(data), content_type=JSON) response = client.post(self.path, data=json.dumps(data), content_type=JSON)
self.assertEqual(response.status_code, 403)
# NOTE (CCB): Ordinarily we would expect a 403; however, since the CSRF validation and session authentication
# fail, DRF considers the request to be unauthenticated.
self.assertEqual(response.status_code, 401)
self.assertIn('CSRF', response.content) self.assertIn('CSRF', response.content)
# Retrieve a CSRF token # Retrieve a CSRF token
......
...@@ -18,8 +18,9 @@ from django.views.decorators.http import require_POST, require_GET ...@@ -18,8 +18,9 @@ from django.views.decorators.http import require_POST, require_GET
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
import pytz import pytz
from rest_framework import viewsets, mixins, permissions, authentication from rest_framework import viewsets, mixins, permissions
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from util.json_request import JsonResponse from util.json_request import JsonResponse
from util.date_utils import from_timestamp from util.date_utils import from_timestamp
from openedx.core.djangoapps.credit import api from openedx.core.djangoapps.credit import api
...@@ -377,17 +378,28 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view ...@@ -377,17 +378,28 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view
lookup_value_regex = settings.COURSE_KEY_REGEX lookup_value_regex = settings.COURSE_KEY_REGEX
queryset = CreditCourse.objects.all() queryset = CreditCourse.objects.all()
serializer_class = CreditCourseSerializer serializer_class = CreditCourseSerializer
authentication_classes = (authentication.OAuth2Authentication, authentication.SessionAuthentication,) authentication_classes = (OAuth2Authentication, SessionAuthentication,)
permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser) permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser)
# In Django Rest Framework v3, there is a default pagination
# class that transmutes the response data into a dictionary
# with pagination information. The original response data (a list)
# is stored in a "results" value of the dictionary.
# For backwards compatibility with the existing API, we disable
# the default behavior by setting the pagination_class to None.
pagination_class = None
# This CSRF exemption only applies when authenticating without SessionAuthentication. # This CSRF exemption only applies when authenticating without SessionAuthentication.
# SessionAuthentication will enforce CSRF protection. # SessionAuthentication will enforce CSRF protection.
@method_decorator(csrf_exempt) @method_decorator(csrf_exempt)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
# Convert the course ID/key from a string to an actual CourseKey object. return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs)
course_id = kwargs.get(self.lookup_field, None)
if course_id: def get_object(self):
kwargs[self.lookup_field] = CourseKey.from_string(course_id) # Convert the serialized course key into a CourseKey instance
# so we can look up the object.
course_key = self.kwargs.get(self.lookup_field)
if course_key is not None:
self.kwargs[self.lookup_field] = CourseKey.from_string(course_key)
return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs) return super(CreditCourseViewSet, self).get_object()
...@@ -30,6 +30,40 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC) ...@@ -30,6 +30,40 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC) TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC)
class PatchedClient(APIClient):
"""
Patch DRF's APIClient to avoid a unicode error on file upload.
Famous last words: This is a *temporary* fix that we should be
able to remove once we upgrade Django past 1.4.
"""
def request(self, *args, **kwargs):
"""Construct an API request. """
# DRF's default test client implementation uses `six.text_type()`
# to convert the CONTENT_TYPE to `unicode`. In Django 1.4,
# this causes a `UnicodeDecodeError` when Django parses a multipart
# upload.
#
# This is the DRF code we're working around:
# https://github.com/tomchristie/django-rest-framework/blob/3.1.3/rest_framework/compat.py#L227
#
# ... and this is the Django code that raises the exception:
#
# https://github.com/django/django/blob/1.4.22/django/http/multipartparser.py#L435
#
# Django unhelpfully swallows the exception, so to the application code
# it appears as though the user didn't send any file data.
#
# This appears to be an issue only with requests constructed in the test
# suite, not with the upload code used in production.
#
if isinstance(kwargs.get("CONTENT_TYPE"), basestring):
kwargs["CONTENT_TYPE"] = str(kwargs["CONTENT_TYPE"])
return super(PatchedClient, self).request(*args, **kwargs)
class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase): class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase):
""" """
Base class / shared infrastructure for tests of profile_image "upload" and Base class / shared infrastructure for tests of profile_image "upload" and
...@@ -111,6 +145,10 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase): ...@@ -111,6 +145,10 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase):
""" """
_view_name = "profile_image_upload" _view_name = "profile_image_upload"
# Use the patched version of the API client to workaround a unicode issue
# with DRF 3.1 and Django 1.4. Remove this after we upgrade Django past 1.4!
client_class = PatchedClient
def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT): def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT):
""" """
Make sure we emit a UserProfile event corresponding to the Make sure we emit a UserProfile event corresponding to the
......
...@@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None): ...@@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None):
serializer.save() serializer.save()
if "language_proficiencies" in update: if "language_proficiencies" in update:
new_language_proficiencies = legacy_profile_serializer.data["language_proficiencies"] new_language_proficiencies = update["language_proficiencies"]
emit_setting_changed_event( emit_setting_changed_event(
user=existing_user, user=existing_user,
db_table=existing_user_profile.language_proficiencies.model._meta.db_table, db_table=existing_user_profile.language_proficiencies.model._meta.db_table,
......
...@@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer): ...@@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer):
super(UserReadOnlySerializer, self).__init__(*args, **kwargs) super(UserReadOnlySerializer, self).__init__(*args, **kwargs)
def to_native(self, user): def to_representation(self, user):
""" """
Overwrite to_native to handle custom logic since we are serializing two models as one here Overwrite to_native to handle custom logic since we are serializing two models as one here
:param user: User object :param user: User object
...@@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea ...@@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
Class that serializes the portion of UserProfile model needed for account information. Class that serializes the portion of UserProfile model needed for account information.
""" """
profile_image = serializers.SerializerMethodField("_get_profile_image") profile_image = serializers.SerializerMethodField("_get_profile_image")
requires_parental_consent = serializers.SerializerMethodField("get_requires_parental_consent") requires_parental_consent = serializers.SerializerMethodField()
language_proficiencies = LanguageProficiencySerializer(many=True, allow_add_remove=True, required=False) language_proficiencies = LanguageProficiencySerializer(many=True, required=False)
class Meta(object): # pylint: disable=missing-docstring class Meta(object): # pylint: disable=missing-docstring
model = UserProfile model = UserProfile
...@@ -165,25 +165,21 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea ...@@ -165,25 +165,21 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
read_only_fields = () read_only_fields = ()
explicit_read_only_fields = ("profile_image", "requires_parental_consent") explicit_read_only_fields = ("profile_image", "requires_parental_consent")
def validate_name(self, attrs, source): def validate_name(self, new_name):
""" Enforce minimum length for name. """ """ Enforce minimum length for name. """
if source in attrs: if len(new_name) < NAME_MIN_LENGTH:
new_name = attrs[source].strip() raise serializers.ValidationError(
if len(new_name) < NAME_MIN_LENGTH: "The name field must be at least {} characters long.".format(NAME_MIN_LENGTH)
raise serializers.ValidationError( )
"The name field must be at least {} characters long.".format(NAME_MIN_LENGTH) return new_name
)
attrs[source] = new_name
return attrs def validate_language_proficiencies(self, value):
def validate_language_proficiencies(self, attrs, source):
""" Enforce all languages are unique. """ """ Enforce all languages are unique. """
language_proficiencies = [language for language in attrs.get(source, [])] language_proficiencies = [language for language in value]
unique_language_proficiencies = set(language.code for language in language_proficiencies) unique_language_proficiencies = set(language["code"] for language in language_proficiencies)
if len(language_proficiencies) != len(unique_language_proficiencies): if len(language_proficiencies) != len(unique_language_proficiencies):
raise serializers.ValidationError("The language_proficiencies field must consist of unique languages") raise serializers.ValidationError("The language_proficiencies field must consist of unique languages")
return attrs return value
def transform_gender(self, user_profile, value): def transform_gender(self, user_profile, value):
""" Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """ """ Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """
...@@ -230,3 +226,29 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea ...@@ -230,3 +226,29 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
call the method with a single argument, the user_profile object. call the method with a single argument, the user_profile object.
""" """
return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user) return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user)
def update(self, instance, validated_data):
"""
Update the profile, including nested fields.
"""
language_proficiencies = validated_data.pop("language_proficiencies", None)
# Update all fields on the user profile that are writeable,
# except for "language_proficiencies", which we'll update separately
update_fields = set(self.get_writeable_fields()) - set(["language_proficiencies"])
for field_name in update_fields:
default = getattr(instance, field_name)
field_value = validated_data.get(field_name, default)
setattr(instance, field_name, field_value)
instance.save()
# Now update the related language proficiency
if language_proficiencies is not None:
instance.language_proficiencies.all().delete()
instance.language_proficiencies.bulk_create([
LanguageProficiency(user_profile=instance, code=language["code"])
for language in language_proficiencies
])
return instance
...@@ -164,7 +164,10 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase): ...@@ -164,7 +164,10 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
field_errors = context_manager.exception.field_errors field_errors = context_manager.exception.field_errors
self.assertEqual(3, len(field_errors)) self.assertEqual(3, len(field_errors))
self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"]) self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"])
self.assertIn("Select a valid choice", field_errors["gender"]["developer_message"]) self.assertIn(
"Value \'undecided\' is not valid for field \'gender\'",
field_errors["gender"]["developer_message"]
)
self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"]) self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"])
@patch('django.core.mail.send_mail') @patch('django.core.mail.send_mail')
......
...@@ -359,16 +359,19 @@ class TestAccountAPI(UserAPITestCase): ...@@ -359,16 +359,19 @@ class TestAccountAPI(UserAPITestCase):
self.assertEqual(404, response.status_code) self.assertEqual(404, response.status_code)
@ddt.data( @ddt.data(
("gender", "f", "not a gender", u"Select a valid choice. not a gender is not one of the available choices."), ("gender", "f", "not a gender", u'"not a gender" is not a valid choice.'),
("level_of_education", "none", u"ȻħȺɍłɇs", u"Select a valid choice. ȻħȺɍłɇs is not one of the available choices."), ("level_of_education", "none", u"ȻħȺɍłɇs", u'"ȻħȺɍłɇs" is not a valid choice.'),
("country", "GB", "XY", u"Select a valid choice. XY is not one of the available choices."), ("country", "GB", "XY", u'"XY" is not a valid choice.'),
("year_of_birth", 2009, "not_an_int", u"Enter a whole number."), ("year_of_birth", 2009, "not_an_int", u"A valid integer is required."),
("name", "bob", "z" * 256, u"Ensure this value has at most 255 characters (it has 256)."), ("name", "bob", "z" * 256, u"Ensure this field has no more than 255 characters."),
("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."), ("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."),
("goals", "Smell the roses"), ("goals", "Smell the roses"),
("mailing_address", "Sesame Street"), ("mailing_address", "Sesame Street"),
# Note that we store the raw data, so it is up to client to escape the HTML. # Note that we store the raw data, so it is up to client to escape the HTML.
("bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>", "z" * 3001, u"Ensure this value has at most 3000 characters (it has 3001)."), (
"bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>",
"z" * 3001, u"Ensure this field has no more than 3000 characters."
),
# Note that email is tested below, as it is not immediately updated. # Note that email is tested below, as it is not immediately updated.
# Note that language_proficiencies is tested below as there are multiple error and success conditions. # Note that language_proficiencies is tested below as there are multiple error and success conditions.
) )
...@@ -568,10 +571,10 @@ class TestAccountAPI(UserAPITestCase): ...@@ -568,10 +571,10 @@ class TestAccountAPI(UserAPITestCase):
self.assertItemsEqual(response.data["language_proficiencies"], proficiencies) self.assertItemsEqual(response.data["language_proficiencies"], proficiencies)
@ddt.data( @ddt.data(
(u"not_a_list", [{u'non_field_errors': [u'Expected a list of items.']}]), (u"not_a_list", {u'non_field_errors': [u'Expected a list of items but got type "unicode".']}),
([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data']}]), ([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data. Expected a dictionary, but got unicode.']}]),
([{}], [{"code": [u"This field is required."]}]), ([{}], [{"code": [u"This field is required."]}]),
([{u"code": u"invalid_language_code"}], [{'code': [u'Select a valid choice. invalid_language_code is not one of the available choices.']}]), ([{u"code": u"invalid_language_code"}], [{'code': [u'"invalid_language_code" is not a valid choice.']}]),
([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']), ([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']),
) )
@ddt.unpack @ddt.unpack
......
...@@ -9,9 +9,10 @@ from django.conf import settings ...@@ -9,9 +9,10 @@ from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db import IntegrityError from django.db import IntegrityError
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from student.models import User, UserProfile
from django.utils.translation import ugettext_noop from django.utils.translation import ugettext_noop
from student.models import User, UserProfile
from request_cache import get_request_or_stub
from ..errors import ( from ..errors import (
UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized, UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized,
PreferenceValidationError, PreferenceUpdateError PreferenceValidationError, PreferenceUpdateError
...@@ -68,7 +69,17 @@ def get_user_preferences(requesting_user, username=None): ...@@ -68,7 +69,17 @@ def get_user_preferences(requesting_user, username=None):
UserAPIInternalError: the operation failed due to an unexpected error. UserAPIInternalError: the operation failed due to an unexpected error.
""" """
existing_user = _get_user(requesting_user, username, allow_staff=True) existing_user = _get_user(requesting_user, username, allow_staff=True)
user_serializer = UserSerializer(existing_user)
# Django Rest Framework V3 uses the current request to version
# hyperlinked URLS, so we need to retrieve the request and pass
# it in the serializer's context (otherwise we get an AssertionError).
# We're retrieving the request from the cache rather than passing it in
# as an argument because this is an implementation detail of how we're
# serializing data, which we want to encapsulate in the API call.
context = {
"request": get_request_or_stub()
}
user_serializer = UserSerializer(existing_user, context=context)
return user_serializer.data["preferences"] return user_serializer.data["preferences"]
...@@ -356,7 +367,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v ...@@ -356,7 +367,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v
developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format( developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format(
preference_key=preference_key, preference_value=preference_value, error=serializer.errors preference_key=preference_key, preference_value=preference_value, error=serializer.errors
) )
if serializer.errors["key"]: if "key" in serializer.errors:
user_message = _(u"Invalid user preference key '{preference_key}'.").format( user_message = _(u"Invalid user preference key '{preference_key}'.").format(
preference_key=preference_key preference_key=preference_key
) )
......
...@@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value): ...@@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value):
preference_key=preference_key, preference_key=preference_key,
preference_value=preference_value, preference_value=preference_value,
error={ error={
"key": [u"Ensure this value has at most 255 characters (it has 256)."] "key": [u"Ensure this field has no more than 255 characters."]
} }
) )
......
...@@ -6,8 +6,8 @@ from .models import UserPreference ...@@ -6,8 +6,8 @@ from .models import UserPreference
class UserSerializer(serializers.HyperlinkedModelSerializer): class UserSerializer(serializers.HyperlinkedModelSerializer):
name = serializers.SerializerMethodField("get_name") name = serializers.SerializerMethodField()
preferences = serializers.SerializerMethodField("get_preferences") preferences = serializers.SerializerMethodField()
def get_name(self, user): def get_name(self, user):
profile = UserProfile.objects.get(user=user) profile = UserProfile.objects.get(user=user)
...@@ -32,9 +32,10 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer): ...@@ -32,9 +32,10 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer):
class RawUserPreferenceSerializer(serializers.ModelSerializer): class RawUserPreferenceSerializer(serializers.ModelSerializer):
"""Serializer that generates a raw representation of a user preference.
""" """
user = serializers.PrimaryKeyRelatedField() Serializer that generates a raw representation of a user preference.
"""
user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all())
class Meta(object): # pylint: disable=missing-docstring class Meta(object): # pylint: disable=missing-docstring
model = UserPreference model = UserPreference
...@@ -57,3 +58,11 @@ class ReadOnlyFieldsSerializerMixin(object): ...@@ -57,3 +58,11 @@ class ReadOnlyFieldsSerializerMixin(object):
cls.Meta.read_only_fields tuple. cls.Meta.read_only_fields tuple.
""" """
return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '') return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '')
@classmethod
def get_writeable_fields(cls):
"""
Return all fields on this serializer that are writeable.
"""
all_fields = getattr(cls.Meta, 'fields', tuple())
return tuple(set(all_fields) - set(cls.get_read_only_fields()))
""" Common Authentication Handlers used across projects. """ """ Common Authentication Handlers used across projects. """
from rest_framework import authentication from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import AuthenticationFailed
from rest_framework.compat import oauth2_provider, provider_now from rest_framework_oauth.compat import oauth2_provider, provider_now
class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthentication): class SessionAuthenticationAllowInactiveUser(SessionAuthentication):
"""Ensure that the user is logged in, but do not require the account to be active. """Ensure that the user is logged in, but do not require the account to be active.
We use this in the special case that a user has created an account, We use this in the special case that a user has created an account,
...@@ -51,7 +52,7 @@ class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthenticatio ...@@ -51,7 +52,7 @@ class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthenticatio
return (user, None) return (user, None)
class OAuth2AuthenticationAllowInactiveUser(authentication.OAuth2Authentication): class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication):
""" """
This is a temporary workaround while the is_active field on the user is coupled This is a temporary workaround while the is_active field on the user is coupled
with whether or not the user has verified ownership of their claimed email address. with whether or not the user has verified ownership of their claimed email address.
......
"""Fields useful for edX API implementations.""" """Fields useful for edX API implementations."""
from django.core.exceptions import ValidationError from rest_framework.serializers import Field
from rest_framework.serializers import CharField, Field
class ExpandableField(Field): class ExpandableField(Field):
...@@ -18,25 +16,19 @@ class ExpandableField(Field): ...@@ -18,25 +16,19 @@ class ExpandableField(Field):
self.expanded = kwargs.pop('expanded_serializer') self.expanded = kwargs.pop('expanded_serializer')
super(ExpandableField, self).__init__(**kwargs) super(ExpandableField, self).__init__(**kwargs)
def field_to_native(self, obj, field_name): def to_representation(self, obj):
"""Converts obj to a native representation, using the expanded serializer if the context requires it.""" """
if 'expand' in self.context and field_name in self.context['expand']: Return a representation of the field that is either expanded or collapsed.
self.expanded.initialize(self, field_name) """
return self.expanded.field_to_native(obj, field_name) should_expand = self.field_name in self.context.get("expand", [])
else: field = self.expanded if should_expand else self.collapsed
self.collapsed.initialize(self, field_name)
return self.collapsed.field_to_native(obj, field_name)
# Avoid double-binding the field, otherwise we'll get
# an error about the source kwarg being redundant.
if field.source is None:
field.bind(self.field_name, self)
class NonEmptyCharField(CharField): if should_expand:
""" self.expanded.context["expand"] = set(field.context.get("expand", []))
A field that enforces non-emptiness even for partial updates.
This is necessary because prior to version 3, DRF skips validation for empty return field.to_representation(obj)
values. Thus, CharField's min_length and RegexField cannot be used to
enforce this constraint.
"""
def validate(self, value):
super(NonEmptyCharField, self).validate(value)
if not value.strip():
raise ValidationError(self.error_messages["required"])
"""
Django Rest Framework view mixins.
"""
from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.mixins import CreateModelMixin
from rest_framework.response import Response
class PutAsCreateMixin(CreateModelMixin):
"""
Backwards compatibility with Django Rest Framework v2, which allowed
creation of a new resource using PUT.
"""
def update(self, request, *args, **kwargs):
"""
Create/update course modes for a course.
"""
# First, try to update the existing instance
try:
try:
return super(PutAsCreateMixin, self).update(request, *args, **kwargs)
except Http404:
# If no instance exists yet, create it.
# This is backwards-compatible with the behavior of DRF v2.
return super(PutAsCreateMixin, self).create(request, *args, **kwargs)
# Backwards compatibility with DRF v2 behavior, which would catch model-level
# validation errors and return a 400
except ValidationError as err:
return Response(err.messages, status=status.HTTP_400_BAD_REQUEST)
...@@ -3,6 +3,31 @@ ...@@ -3,6 +3,31 @@
from django.http import Http404 from django.http import Http404
from django.core.paginator import Paginator, InvalidPage from django.core.paginator import Paginator, InvalidPage
from rest_framework.response import Response
from rest_framework import pagination
class DefaultPagination(pagination.PageNumberPagination):
"""
Default paginator for APIs in edx-platform.
This is configured in settings to be automatically used
by any subclass of Django Rest Framework's generic API views.
"""
page_size_query_param = "page_size"
def get_paginated_response(self, data):
"""
Annotate the response with pagination information.
"""
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'count': self.page.paginator.count,
'num_pages': self.page.paginator.num_pages,
'results': data
})
def paginate_search_results(object_class, search_results, page_size, page): def paginate_search_results(object_class, search_results, page_size, page):
""" """
......
from rest_framework import pagination, serializers """
Serializers to be used in APIs.
"""
from rest_framework import serializers
class PaginationSerializer(pagination.PaginationSerializer):
"""
Custom PaginationSerializer for openedx.
Adds the following fields:
- num_pages: total number of pages
- current_page: the current page being returned
- start: the index of the first page item within the overall collection
"""
start_page = 1 # django Paginator.page objects have 1-based indexes
num_pages = serializers.Field(source='paginator.num_pages')
current_page = serializers.SerializerMethodField('get_current_page')
start = serializers.SerializerMethodField('get_start')
sort_order = serializers.SerializerMethodField('get_sort_order')
def get_current_page(self, page):
"""Get the current page"""
return page.number
def get_start(self, page):
"""Get the index of the first page item within the overall collection"""
return (self.get_current_page(page) - self.start_page) * page.paginator.per_page
def get_sort_order(self, page): # pylint: disable=unused-argument
"""Get the order by which this collection was sorted"""
return self.context.get('sort_order')
class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
...@@ -54,9 +30,10 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer): ...@@ -54,9 +30,10 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs) super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs)
self.fields[id_source] = serializers.CharField(read_only=True, source=id_source) self.fields[id_source] = serializers.CharField(read_only=True)
self.fields['url'].view_name = view_name self.fields['url'].view_name = view_name
self.fields['url'].lookup_field = lookup_field self.fields['url'].lookup_field = lookup_field
self.fields['url'].lookup_url_kwarg = lookup_field
class Meta(object): class Meta(object):
"""Defines meta information for the ModelSerializer. """Defines meta information for the ModelSerializer.
......
...@@ -9,6 +9,7 @@ from django.utils.translation import ugettext as _ ...@@ -9,6 +9,7 @@ from django.utils.translation import ugettext as _
from rest_framework import status, response from rest_framework import status, response
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.request import clone_request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin
from rest_framework.generics import GenericAPIView from rest_framework.generics import GenericAPIView
...@@ -193,3 +194,23 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView) ...@@ -193,3 +194,23 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView)
add_serializer_errors(serializer, patch, field_errors) add_serializer_errors(serializer, patch, field_errors)
return field_errors return field_errors
def get_object_or_none(self):
"""
Retrieve an object or return None if the object can't be found.
NOTE: This replaces functionality that was removed in Django Rest Framework v3.1.
"""
try:
return self.get_object()
except Http404:
if self.request.method == 'PUT':
# For PUT-as-create operation, we need to ensure that we have
# relevant permissions, as if this was a POST request. This
# will either raise a PermissionDenied exception, or simply
# return None.
self.check_permissions(clone_request(self.request, 'POST'))
else:
# PATCH requests where the object does not exist should still
# return a 404 response.
raise
...@@ -28,7 +28,7 @@ django-ses==0.7.0 ...@@ -28,7 +28,7 @@ django-ses==0.7.0
django-simple-history==1.6.3 django-simple-history==1.6.3
django-storages==1.1.5 django-storages==1.1.5
django-method-override==0.1.0 django-method-override==0.1.0
djangorestframework==2.3.14 djangorestframework>=3.1,<3.2
django==1.4.22 django==1.4.22
elasticsearch==0.4.5 elasticsearch==0.4.5
facebook-sdk==0.4.0 facebook-sdk==0.4.0
......
...@@ -12,6 +12,7 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df ...@@ -12,6 +12,7 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df
-e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline -e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline
-e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki -e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki
-e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider -e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider
-e git+https://github.com/edx/django-rest-framework-oauth.git@f0b503fda8c254a38f97fef802ded4f5fe367f7a#egg=djangorestframework-oauth
-e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy -e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy
git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6 git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6
-e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev -e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev
...@@ -40,13 +41,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx ...@@ -40,13 +41,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx
-e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking -e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking
-e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash -e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash
-e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock -e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock
-e git+https://github.com/edx/edx-ora2.git@release-2015-08-25T16.16#egg=edx-ora2 -e git+https://github.com/edx/edx-ora2.git@release-2015-09-16T15.28#egg=edx-ora2
-e git+https://github.com/edx/edx-submissions.git@9538ee8a971d04dc1cb05e88f6aa0c36b224455c#egg=edx-submissions -e git+https://github.com/edx/edx-submissions.git@0.1.0#egg=edx-submissions
-e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys -e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys
git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3 git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3
git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3 git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3
git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7 git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7
-e git+https://github.com/edx/edx-val.git@v0.0.5#egg=edx-val -e git+https://github.com/edx/edx-val.git@0.0.6#egg=edx-val
-e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock -e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock
-e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search -e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search
-e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones -e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment