Commit 00473d44 by Brian Beggs

Merge pull request #9929 from edx/bbeggs/merge-DRF-3.1

Merge DRF 3.1 in to master
parents 78f8fc39 53524efb
......@@ -44,7 +44,10 @@ from lms.envs.common import (
PROFILE_IMAGE_SECRET_KEY, PROFILE_IMAGE_MIN_BYTES, PROFILE_IMAGE_MAX_BYTES,
# The following setting is included as it is used to check whether to
# display credit eligibility table on the CMS or not.
ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY
ENABLE_CREDIT_ELIGIBILITY, YOUTUBE_API_KEY,
# Django REST framework configuration
REST_FRAMEWORK,
)
from path import Path as path
from warnings import simplefilter
......
......@@ -6,7 +6,7 @@ from django.test.utils import override_settings
from django.test.client import RequestFactory
from django.conf import settings
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.exceptions import PermissionDenied
from cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf
......@@ -24,7 +24,7 @@ class CrossDomainAuthTest(TestCase):
def test_perform_csrf_referer_check(self):
request = self._fake_request()
with self.assertRaisesRegexp(AuthenticationFailed, 'CSRF'):
with self.assertRaisesRegexp(PermissionDenied, 'CSRF'):
self.auth.enforce_csrf(request)
@patch.dict(settings.FEATURES, {
......
......@@ -11,7 +11,7 @@ from enrollment.errors import (
CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError,
CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute
)
from enrollment.serializers import CourseEnrollmentSerializer, CourseField
from enrollment.serializers import CourseEnrollmentSerializer, CourseSerializer
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from student.models import (
CourseEnrollment, NonExistentCourseError, EnrollmentClosedError,
......@@ -35,9 +35,30 @@ def get_course_enrollments(user_id):
"""
qset = CourseEnrollment.objects.filter(
user__username=user_id, is_active=True
user__username=user_id,
is_active=True
).order_by('created')
return CourseEnrollmentSerializer(qset).data
enrollments = CourseEnrollmentSerializer(qset, many=True).data
# Find deleted courses and filter them out of the results
deleted = []
valid = []
for enrollment in enrollments:
if enrollment.get("course_details") is not None:
valid.append(enrollment)
else:
deleted.append(enrollment)
if deleted:
log.warning(
(
u"Course enrollments for user %s reference "
u"courses that do not exist (this can occur if a course is deleted)."
), user_id,
)
return valid
def get_course_enrollment(username, course_id):
......@@ -271,4 +292,4 @@ def get_course_enrollment_info(course_id, include_expired=False):
log.warning(msg)
raise CourseNotFoundError(msg)
else:
return CourseField().to_native(course, include_expired=include_expired)
return CourseSerializer(course, include_expired=include_expired).data
......@@ -30,32 +30,36 @@ class StringListField(serializers.CharField):
return [int(item) for item in items]
class CourseField(serializers.RelatedField):
"""Read-Only representation of course enrollment information.
class CourseSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Serialize a course descriptor and related information.
"""
Aggregates course information from the CourseDescriptor as well as the Course Modes configured
for enrolling in the course.
course_id = serializers.CharField(source="id")
enrollment_start = serializers.DateTimeField(format=None)
enrollment_end = serializers.DateTimeField(format=None)
course_start = serializers.DateTimeField(source="start", format=None)
course_end = serializers.DateTimeField(source="end", format=None)
invite_only = serializers.BooleanField(source="invitation_only")
course_modes = serializers.SerializerMethodField()
"""
def __init__(self, *args, **kwargs):
self.include_expired = kwargs.pop("include_expired", False)
super(CourseSerializer, self).__init__(*args, **kwargs)
def to_native(self, course, **kwargs):
course_modes = ModeSerializer(
CourseMode.modes_for_course(
course.id,
include_expired=kwargs.get('include_expired', False),
only_selectable=False
)
).data
return {
'course_id': unicode(course.id),
'enrollment_start': course.enrollment_start,
'enrollment_end': course.enrollment_end,
'course_start': course.start,
'course_end': course.end,
'invite_only': course.invitation_only,
'course_modes': course_modes,
}
def get_course_modes(self, obj):
"""
Retrieve course modes associated with the course.
"""
course_modes = CourseMode.modes_for_course(
obj.id,
include_expired=self.include_expired,
only_selectable=False
)
return [
ModeSerializer(mode).data
for mode in course_modes
]
class CourseEnrollmentSerializer(serializers.ModelSerializer):
......@@ -65,34 +69,9 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
the Course Descriptor and course modes, to give a complete representation of course enrollment.
"""
course_details = serializers.SerializerMethodField('get_course_details')
course_details = CourseSerializer(source="course_overview")
user = serializers.SerializerMethodField('get_username')
@property
def data(self):
serialized_data = super(CourseEnrollmentSerializer, self).data
# filter the results with empty courses 'course_details'
if isinstance(serialized_data, dict):
if serialized_data.get('course_details') is None:
return None
return serialized_data
return [enrollment for enrollment in serialized_data if enrollment.get('course_details')]
def get_course_details(self, model):
if model.course is None:
msg = u"Course '{0}' does not exist (maybe deleted), in which User (user_id: '{1}') is enrolled.".format(
model.course_id,
model.user.id
)
log.warning(msg)
return None
field = CourseField()
return field.to_native(model.course)
def get_username(self, model):
"""Retrieves the username from the associated model."""
return model.username
......
......@@ -1038,7 +1038,7 @@ class EnrollmentCrossDomainTest(ModuleStoreTestCase):
@cross_domain_config
def test_cross_domain_missing_csrf(self, *args): # pylint: disable=unused-argument
resp = self._cross_domain_post('invalid_csrf_token')
self.assertEqual(resp.status_code, 401)
self.assertEqual(resp.status_code, 403)
def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """
......
......@@ -5,10 +5,18 @@ This module requires that :class:`request_cache.middleware.RequestCache`
is installed in order to clear the cache after each request.
"""
import logging
from urlparse import urlparse
from django.conf import settings
from django.test.client import RequestFactory
from request_cache import middleware
log = logging.getLogger(__name__)
def get_cache(name):
"""
Return the request cache named ``name``.
......@@ -26,3 +34,38 @@ def get_request():
Return the current request.
"""
return middleware.RequestCache.get_current_request()
def get_request_or_stub():
"""
Return the current request or a stub request.
If called outside the context of a request, construct a fake
request that can be used to build an absolute URI.
This is useful in cases where we need to pass in a request object
but don't have an active request (for example, in test cases).
"""
request = get_request()
if request is None:
log.warning(
"Could not retrieve the current request. "
"A stub request will be created instead using settings.SITE_NAME. "
"This should be used *only* in test cases, never in production!"
)
# The settings SITE_NAME may contain a port number, so we need to
# parse the full URL.
full_url = "http://{site_name}".format(site_name=settings.SITE_NAME)
parsed_url = urlparse(full_url)
# Construct the fake request. This can be used to construct absolute
# URIs to other paths.
return RequestFactory(
SERVER_NAME=parsed_url.hostname,
SERVER_PORT=parsed_url.port or 80,
).get("/")
else:
return request
"""
Tests for the request cache.
"""
from django.conf import settings
from django.test import TestCase
from request_cache import get_request_or_stub
class TestRequestCache(TestCase):
"""
Tests for the request cache.
"""
def test_get_request_or_stub(self):
# Outside the context of the request, we should still get a request
# that allows us to build an absolute URI.
stub = get_request_or_stub()
expected_url = "http://{site_name}/foobar".format(site_name=settings.SITE_NAME)
self.assertEqual(stub.build_absolute_uri("foobar"), expected_url)
......@@ -406,7 +406,7 @@ class BrowseTopicsTest(TeamsTabBase):
)
create_team_page.submit_form()
team_page = TeamPage(self.browser, self.course_id)
self.assertTrue(team_page.is_browser_on_page)
self.assertTrue(team_page.is_browser_on_page())
team_page.click_all_topics()
self.assertTrue(self.topics_page.is_browser_on_page())
self.topics_page.wait_for_ajax()
......
......@@ -18,7 +18,12 @@ class CourseModeSerializer(serializers.ModelSerializer):
""" CourseMode serializer. """
name = serializers.CharField(source='mode_slug')
price = serializers.IntegerField(source='min_price')
expires = serializers.DateTimeField(source='expiration_datetime', required=False, blank=True)
expires = serializers.DateTimeField(
source='expiration_datetime',
required=False,
allow_null=True,
format=None
)
def get_identity(self, data):
try:
......@@ -56,8 +61,8 @@ class CourseSerializer(serializers.Serializer):
""" Course serializer. """
id = serializers.CharField(validators=[validate_course_id]) # pylint: disable=invalid-name
name = serializers.CharField(read_only=True)
verification_deadline = serializers.DateTimeField(blank=True)
modes = CourseModeSerializer(many=True, allow_add_remove=True)
verification_deadline = serializers.DateTimeField(format=None, allow_null=True, required=False)
modes = CourseModeSerializer(many=True)
def validate(self, attrs):
""" Ensure the verification deadline occurs AFTER the course mode enrollment deadlines. """
......@@ -68,7 +73,7 @@ class CourseSerializer(serializers.Serializer):
# Find the earliest upgrade deadline
for mode in attrs['modes']:
expires = mode.expiration_datetime
expires = mode.get("expiration_datetime")
if expires:
# If we don't already have an upgrade_deadline value, use datetime.max so that we can actually
# complete the comparison.
......@@ -82,9 +87,28 @@ class CourseSerializer(serializers.Serializer):
return attrs
def restore_object(self, attrs, instance=None):
if instance is None:
return Course(attrs['id'], attrs['modes'], attrs['verification_deadline'])
def create(self, validated_data):
"""Create course modes for a course. """
course = Course(
validated_data["id"],
self._new_course_mode_models(validated_data["modes"]),
verification_deadline=validated_data["verification_deadline"]
)
course.save()
return course
def update(self, instance, validated_data):
"""Update course modes for an existing course. """
validated_data["modes"] = self._new_course_mode_models(validated_data["modes"])
instance.update(attrs)
instance.update(validated_data)
instance.save()
return instance
@staticmethod
def _new_course_mode_models(modes_data):
"""Convert validated course mode data to CourseMode objects. """
return [
CourseMode(**modes_dict)
for modes_dict in modes_data
]
......@@ -2,7 +2,8 @@
import logging
from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView
from rest_framework.permissions import IsAuthenticated
......@@ -10,6 +11,7 @@ from commerce.api.v1.models import Course
from commerce.api.v1.permissions import ApiKeyOrModelPermission
from commerce.api.v1.serializers import CourseSerializer
from course_modes.models import CourseMode
from openedx.core.lib.api.mixins import PutAsCreateMixin
log = logging.getLogger(__name__)
......@@ -19,12 +21,13 @@ class CourseListView(ListAPIView):
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated,)
serializer_class = CourseSerializer
pagination_class = None
def get_queryset(self):
return Course.iterator()
return list(Course.iterator())
class CourseRetrieveUpdateView(RetrieveUpdateAPIView):
class CourseRetrieveUpdateView(PutAsCreateMixin, RetrieveUpdateAPIView):
""" Retrieve, update, or create courses/modes. """
lookup_field = 'id'
lookup_url_kwarg = 'course_id'
......@@ -33,6 +36,11 @@ class CourseRetrieveUpdateView(RetrieveUpdateAPIView):
permission_classes = (ApiKeyOrModelPermission,)
serializer_class = CourseSerializer
# Django Rest Framework v3 requires that we provide a queryset.
# Note that we're overriding `get_object()` below to return a `Course`
# rather than a CourseMode, so this isn't really used.
queryset = CourseMode.objects.all()
def get_object(self, queryset=None):
course_id = self.kwargs.get(self.lookup_url_kwarg)
course = Course.get(course_id)
......
......@@ -11,11 +11,11 @@ class CourseSerializer(serializers.Serializer):
id = serializers.CharField() # pylint: disable=invalid-name
name = serializers.CharField(source='display_name')
category = serializers.CharField()
org = serializers.SerializerMethodField('get_org')
run = serializers.SerializerMethodField('get_run')
course = serializers.SerializerMethodField('get_course')
uri = serializers.SerializerMethodField('get_uri')
image_url = serializers.SerializerMethodField('get_image_url')
org = serializers.SerializerMethodField()
run = serializers.SerializerMethodField()
course = serializers.SerializerMethodField()
uri = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField()
start = serializers.DateTimeField()
end = serializers.DateTimeField()
......
......@@ -36,6 +36,23 @@ class CourseViewTestsMixin(object):
"""
view = None
raw_grader = [
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
"min_count": 4,
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
]
def setUp(self):
super(CourseViewTestsMixin, self).setUp()
self.create_user_and_access_token()
......@@ -51,22 +68,7 @@ class CourseViewTestsMixin(object):
@classmethod
def create_course_data(cls):
cls.invalid_course_id = 'foo/bar/baz'
cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=[
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
"min_count": 4,
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
])
cls.course = CourseFactory.create(display_name='An Introduction to API Testing', raw_grader=cls.raw_grader)
cls.course_id = unicode(cls.course.id)
with cls.store.bulk_operations(cls.course.id, emit_signals=False):
cls.sequential = ItemFactory.create(
......@@ -408,6 +410,55 @@ class CourseGradingPolicyTests(CourseDetailTestMixin, CourseViewTestsMixin, Shar
self.assertListEqual(response.data, expected)
class CourseGradingPolicyMissingFieldsTests(CourseDetailTestMixin, CourseViewTestsMixin, SharedModuleStoreTestCase):
view = 'course_structure_api:v0:grading_policy'
# Update the raw grader to have missing keys
raw_grader = [
{
"min_count": 24,
"weight": 0.2,
"type": "Homework",
"drop_count": 0,
"short_label": "HW"
},
{
# Deleted "min_count" key
"weight": 0.8,
"type": "Exam",
"drop_count": 0,
"short_label": "Exam"
}
]
@classmethod
def setUpClass(cls):
super(CourseGradingPolicyMissingFieldsTests, cls).setUpClass()
cls.create_course_data()
def test_get(self):
"""
The view should return grading policy for a course.
"""
response = super(CourseGradingPolicyMissingFieldsTests, self).test_get()
expected = [
{
"count": 24,
"weight": 0.2,
"assignment_type": "Homework",
"dropped": 0
},
{
"count": None,
"weight": 0.8,
"assignment_type": "Exam",
"dropped": 0
}
]
self.assertListEqual(response.data, expected)
#####################################################################################
#
# The following Mixins/Classes collectively test the CourseBlocksAndNavigation view.
......
......@@ -6,7 +6,8 @@ import logging
from django.conf import settings
from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.exceptions import AuthenticationFailed, ParseError
from rest_framework.generics import RetrieveAPIView, ListAPIView
from rest_framework.permissions import IsAuthenticated
......@@ -21,7 +22,6 @@ from courseware.access import has_access
from courseware.model_data import FieldDataCache
from courseware.module_render import get_module_for_descriptor
from openedx.core.lib.api.view_utils import view_course_access, view_auth_classes
from openedx.core.lib.api.serializers import PaginationSerializer
from openedx.core.djangoapps.content.course_structures.api.v0 import api, errors
from student.roles import CourseInstructorRole, CourseStaffRole
from util.module_utils import get_dynamic_descriptor_children
......@@ -157,9 +157,6 @@ class CourseList(CourseViewMixin, ListAPIView):
* end: The course end date. If course end date is not specified, the
value is null.
"""
paginate_by = 10
paginate_by_param = 'page_size'
pagination_serializer_class = PaginationSerializer
serializer_class = serializers.CourseSerializer
def get_queryset(self):
......
......@@ -25,7 +25,6 @@ from xmodule.modulestore.django import modulestore
from xmodule.modulestore.exceptions import ItemNotFoundError
from .models import StudentModule
from .module_render import get_module_for_descriptor
from submissions import api as sub_api # installed from the edx-submissions repository
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
from openedx.core.djangoapps.signals.signals import GRADES_UPDATED
......@@ -349,8 +348,13 @@ def _grade(student, request, course, keep_raw_scores, field_data_cache, scores_c
# Dict of item_ids -> (earned, possible) point tuples. This *only* grabs
# scores that were registered with the submissions API, which for the moment
# means only openassessment (edx-ora2)
# We need to import this here to avoid a circular dependency of the form:
# XBlock --> submissions --> Django Rest Framework error strings -->
# Django translation --> ... --> courseware --> submissions
from submissions import api as sub_api # installed from the edx-submissions repository
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
max_scores_cache = MaxScoresCache.create_for_course(course)
# For the moment, we have to get scorable_locations from field_data_cache
# and not from scores_client, because scores_client is ignorant of things
# in the submissions API. As a further refactoring step, submissions should
......@@ -565,7 +569,12 @@ def _progress_summary(student, request, course, field_data_cache=None, scores_cl
course_module = getattr(course_module, '_x_module', course_module)
# We need to import this here to avoid a circular dependency of the form:
# XBlock --> submissions --> Django Rest Framework error strings -->
# Django translation --> ... --> courseware --> submissions
from submissions import api as sub_api # installed from the edx-submissions repository
submissions_scores = sub_api.get_scores(course.id.to_deprecated_string(), anonymous_id_for_user(student, course.id))
max_scores_cache = MaxScoresCache.create_for_course(course)
# For the moment, we have to get scorable_locations from field_data_cache
# and not from scores_client, because scores_client is ignorant of things
......
......@@ -560,7 +560,7 @@ def create_thread(request, thread_data):
if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save()
cc_thread = serializer.object
cc_thread = serializer.instance
thread_created.send(sender=None, user=user, post=cc_thread)
api_thread = serializer.data
_do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
......@@ -606,7 +606,7 @@ def create_comment(request, comment_data):
if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save()
cc_comment = serializer.object
cc_comment = serializer.instance
comment_created.send(sender=None, user=request.user, post=cc_comment)
api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context)
......
"""
Discussion API pagination support
"""
from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField
class _PaginationSerializer(BasePaginationSerializer):
"""
A pagination serializer without the count field, because the Comments
Service does not return result counts
"""
next = NextPageField(source="*")
previous = PreviousPageField(source="*")
from rest_framework.utils.urls import replace_query_param
class _Page(object):
......@@ -52,7 +43,25 @@ def get_paginated_data(request, results, page_num, per_page):
previous: The URL for the previous page
results: The results on this page
"""
return _PaginationSerializer(
instance=_Page(results, page_num, per_page),
context={"request": request}
).data
# Note: Previous versions of this function used Django Rest Framework's
# paginated serializer. With the upgrade to DRF 3.1, paginated serializers
# have been removed. We *could* use DRF's paginator classes, but there are
# some slight differences between how DRF does pagination and how we're doing
# pagination here. (For example, we respond with a next_url param even if
# there is only one result on the current page.) To maintain backwards
# compatability, we simulate the behavior that DRF used to provide.
page = _Page(results, page_num, per_page)
next_url, previous_url = None, None
base_url = request.build_absolute_uri()
if page.has_next():
next_url = replace_query_param(base_url, "page", page.next_page_number())
if page.has_previous():
previous_url = replace_query_param(base_url, "page", page.previous_page_number())
return {
"next": next_url,
"previous": previous_url,
"results": results,
}
......@@ -1497,8 +1497,9 @@ class CreateThreadTest(
self.assertEqual(actual_post_data["group_id"], [str(cohort.id)])
else:
self.assertNotIn("group_id", actual_post_data)
except ValidationError:
self.assertTrue(expected_error)
except ValidationError as ex:
if not expected_error:
self.fail("Unexpected validation error: {}".format(ex))
def test_following(self):
self.register_post_thread_response({"id": "test_id"})
......@@ -2239,7 +2240,7 @@ class UpdateThreadTest(
update_thread(self.request, "test_thread", {"raw_body": ""})
self.assertEqual(
assertion.exception.message_dict,
{"raw_body": ["This field is required."]}
{"raw_body": ["This field may not be blank."]}
)
......
......@@ -523,9 +523,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
data = self.minimal_data.copy()
data.update({field: value for field in ["topic_id", "title", "raw_body"]})
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
{field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
)
def test_create_type(self):
......@@ -592,9 +593,10 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
partial=True,
context=get_context(self.course, self.request)
)
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
{field: ["This field may not be blank."] for field in ["topic_id", "title", "raw_body"]}
)
def test_update_course_id(self):
......@@ -604,6 +606,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
partial=True,
context=get_context(self.course, self.request)
)
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{"course_id": ["This field is not allowed in an update."]}
......@@ -769,7 +772,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
data["parent_id"] = None
serializer = CommentSerializer(data=data, context=context)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {"parent_id": ["Comment level is too deep."]})
self.assertEqual(serializer.errors, {"non_field_errors": ["Comment level is too deep."]})
def test_create_missing_field(self):
for field in self.minimal_data:
......@@ -855,9 +858,10 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
partial=True,
context=get_context(self.course, self.request)
)
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{"raw_body": ["This field is required."]}
{"raw_body": ["This field may not be blank."]}
)
@ddt.data("thread_id", "parent_id")
......@@ -868,6 +872,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, SharedModul
partial=True,
context=get_context(self.course, self.request)
)
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{field: ["This field is not allowed in an update."]}
......
......@@ -571,7 +571,7 @@ class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTest
content_type="application/json"
)
expected_response_data = {
"field_errors": {"title": {"developer_message": "This field is required."}}
"field_errors": {"title": {"developer_message": "This field may not be blank."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
......@@ -690,7 +690,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
200,
{
"results": expected_comments,
"next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format(
"next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format(
self.thread_id
),
"previous": None,
......@@ -946,7 +946,7 @@ class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTes
content_type="application/json"
)
expected_response_data = {
"field_errors": {"raw_body": {"developer_message": "This field is required."}}
"field_errors": {"raw_body": {"developer_message": "This field may not be blank."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
......
......@@ -3,7 +3,8 @@ Discussion API views
"""
from django.core.exceptions import ValidationError
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
......
......@@ -32,8 +32,6 @@ from xmodule.modulestore.tests.factories import check_mongo_calls
from xmodule.modulestore.django import modulestore
from xmodule.modulestore import ModuleStoreEnum
from teams.tests.factories import CourseTeamFactory
log = logging.getLogger(__name__)
......@@ -1290,6 +1288,7 @@ class TeamsPermissionsTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSe
topic_id='topic_id',
discussion_topic_id=self.team_commentable_id
)
self.team.add_user(self.student_in_team)
# Dummy commentable ID not linked to a team
......
......@@ -717,6 +717,7 @@ class InlineDiscussionContextTestCase(ModuleStoreTestCase):
topic_id='topic_id',
discussion_topic_id=self.discussion_topic_id
)
self.team.add_user(self.user) # pylint: disable=no-member
def test_context_can_be_standalone(self, mock_request):
......@@ -1093,7 +1094,9 @@ class InlineDiscussionTestCase(ModuleStoreTestCase):
course_id=self.course.id,
discussion_topic_id=self.discussion1.discussion_id
)
team.add_user(self.student) # pylint: disable=no-member
response = self.send_request(mock_request)
self.assertEqual(mock_request.call_args[1]['params']['context'], ThreadContext.STANDALONE)
self.verify_response(response)
......
......@@ -149,6 +149,8 @@ class NonCohortedTopicGroupIdTestMixin(GroupIdAssertionMixin):
def test_team_discussion_id_not_cohorted(self, mock_request):
team = CourseTeamFactory(course_id=self.course.id)
team.add_user(self.student) # pylint: disable=no-member
self.call_view(mock_request, team.discussion_topic_id, self.student, None)
self._assert_comments_service_called_without_group_id(mock_request)
......@@ -31,7 +31,7 @@ class CoursesWithFriends(generics.ListAPIView):
serializer_class = serializers.CoursesWithFriendsSerializer
def list(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.GET, files=request.FILES)
serializer = self.get_serializer(data=request.GET)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......@@ -61,4 +61,5 @@ class CoursesWithFriends(generics.ListAPIView):
and is_mobile_available_for_user(self.request.user, enrollment.course)
]
return Response(CourseEnrollmentSerializer(courses, context={'request': request}).data)
serializer = CourseEnrollmentSerializer(courses, context={'request': request}, many=True)
return Response(serializer.data)
......@@ -40,7 +40,7 @@ class FriendsInCourse(generics.ListAPIView):
serializer_class = serializers.FriendsInCourseSerializer
def list(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.GET, files=request.FILES)
serializer = self.get_serializer(data=request.GET)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
......@@ -45,7 +45,7 @@ class Groups(generics.CreateAPIView, mixins.DestroyModelMixin):
serializer_class = serializers.GroupSerializer
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
serializer = self.get_serializer(data=request.DATA)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
try:
......@@ -106,12 +106,12 @@ class GroupsMembers(generics.CreateAPIView, mixins.DestroyModelMixin):
serializer_class = serializers.GroupsMembersSerializer
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
serializer = self.get_serializer(data=request.DATA)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
graph = facebook_graph_api()
url = settings.FACEBOOK_API_VERSION + '/' + kwargs['group_id'] + "/members"
member_ids = serializer.object['member_ids'].split(',')
member_ids = serializer.data['member_ids'].split(',')
response = {}
for member_id in member_ids:
try:
......
......@@ -8,4 +8,4 @@ class UserSharingSerializar(serializers.Serializer):
"""
Serializes user social settings
"""
share_with_facebook_friends = serializers.BooleanField(required=True, default=False)
share_with_facebook_friends = serializers.BooleanField(required=True)
......@@ -39,9 +39,9 @@ class UserSharing(generics.ListCreateAPIView):
serializer_class = serializers.UserSharingSerializar
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid():
value = serializer.object['share_with_facebook_friends']
value = serializer.data['share_with_facebook_friends']
set_user_preference(request.user, "share_with_facebook_friends", value)
return self.get(request, *args, **kwargs)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
......
......@@ -40,7 +40,7 @@ def get_friends_from_facebook(serializer):
the error message.
"""
try:
graph = facebook.GraphAPI(serializer.object['oauth_token'])
graph = facebook.GraphAPI(serializer.data['oauth_token'])
friends = graph.request(settings.FACEBOOK_API_VERSION + "/me/friends")
return get_pagination(friends)
except facebook.GraphAPIError, ex:
......
......@@ -15,7 +15,7 @@ from xmodule.course_module import DEFAULT_START_DATE
class CourseOverviewField(serializers.RelatedField):
"""Custom field to wrap a CourseDescriptor object. Read-only."""
def to_native(self, course_overview):
def to_representation(self, course_overview):
course_id = unicode(course_overview.id)
request = self.context.get('request', None)
if request:
......@@ -77,8 +77,8 @@ class CourseEnrollmentSerializer(serializers.ModelSerializer):
"""
Serializes CourseEnrollment models
"""
course = CourseOverviewField(source="course_overview")
certificate = serializers.SerializerMethodField('get_certificate')
course = CourseOverviewField(source="course_overview", read_only=True)
certificate = serializers.SerializerMethodField()
def get_certificate(self, model):
"""Returns the information about the user's certificate in the course."""
......@@ -100,7 +100,7 @@ class UserSerializer(serializers.HyperlinkedModelSerializer):
"""
Serializes User models
"""
name = serializers.Field(source='profile.name')
name = serializers.ReadOnlyField(source='profile.name')
course_enrollments = serializers.HyperlinkedIdentityField(
view_name='courseenrollment-detail',
lookup_field='username'
......
......@@ -251,6 +251,14 @@ class UserCourseEnrollmentsList(generics.ListAPIView):
serializer_class = CourseEnrollmentSerializer
lookup_field = 'username'
# In Django Rest Framework v3, there is a default pagination
# class that transmutes the response data into a dictionary
# with pagination information. The original response data (a list)
# is stored in a "results" value of the dictionary.
# For backwards compatibility with the existing API, we disable
# the default behavior by setting the pagination_class to None.
pagination_class = None
def get_queryset(self):
enrollments = self.queryset.filter(
user__username=self.kwargs['username'],
......
......@@ -23,9 +23,9 @@ class NotifierUserSerializer(serializers.ModelSerializer):
* course_groups
* roles__permissions
"""
name = serializers.SerializerMethodField("get_name")
preferences = serializers.SerializerMethodField("get_preferences")
course_info = serializers.SerializerMethodField("get_course_info")
name = serializers.SerializerMethodField()
preferences = serializers.SerializerMethodField()
course_info = serializers.SerializerMethodField()
def get_name(self, user):
return user.profile.name
......
from django.contrib.auth.models import User
from rest_framework.viewsets import ReadOnlyModelViewSet
from rest_framework.response import Response
from rest_framework import pagination
from notification_prefs import NOTIFICATION_PREF_KEY
from notifier_api.serializers import NotifierUserSerializer
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
class NotifierPaginator(pagination.PageNumberPagination):
"""
Paginator for the notifier API.
"""
page_size = 10
page_size_query_param = "page_size"
def get_paginated_response(self, data):
"""
Construct a response with pagination information.
"""
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'count': self.page.paginator.count,
'results': data
})
class NotifierUsersViewSet(ReadOnlyModelViewSet):
"""
An endpoint that the notifier can use to retrieve users who have enabled
......@@ -14,8 +35,7 @@ class NotifierUsersViewSet(ReadOnlyModelViewSet):
"""
permission_classes = (ApiKeyHeaderPermission,)
serializer_class = NotifierUserSerializer
paginate_by = 10
paginate_by_param = "page_size"
pagination_class = NotifierPaginator
# See NotifierUserSerializer for notes about related tables
queryset = User.objects.filter(
......
......@@ -3,7 +3,6 @@
from datetime import datetime
from uuid import uuid4
import pytz
from datetime import datetime
from model_utils import FieldTracker
from django.core.exceptions import ObjectDoesNotExist
......
......@@ -10,6 +10,7 @@ from django.utils import translation
from functools import wraps
from search.search_engine_base import SearchEngine
from request_cache import get_request_or_stub
from .errors import ElasticSearchConnectionError
from .serializers import CourseTeamSerializer, CourseTeam
......@@ -47,7 +48,15 @@ class CourseTeamIndexer(object):
Returns serialized object with additional search fields.
"""
serialized_course_team = CourseTeamSerializer(self.course_team).data
# Django Rest Framework v3.1 requires that we pass the request to the serializer
# so it can construct hyperlinks. To avoid changing the interface of this object,
# we retrieve the request from the request cache.
context = {
"request": get_request_or_stub()
}
serialized_course_team = CourseTeamSerializer(self.course_team, context=context).data
# Save the primary key so we can load the full objects easily after we search
serialized_course_team['pk'] = self.course_team.pk
# Don't save the membership relations in elasticsearch
......
......@@ -4,15 +4,44 @@ from django.contrib.auth.models import User
from django.db.models import Count
from django.conf import settings
from django_countries import countries
from rest_framework import serializers
from openedx.core.lib.api.serializers import CollapsedReferenceSerializer, PaginationSerializer
from openedx.core.lib.api.serializers import CollapsedReferenceSerializer
from openedx.core.lib.api.fields import ExpandableField
from openedx.core.djangoapps.user_api.accounts.serializers import UserReadOnlySerializer
from .models import CourseTeam, CourseTeamMembership
class CountryField(serializers.Field):
"""
Field to serialize a country code.
"""
COUNTRY_CODES = dict(countries).keys()
def to_representation(self, obj):
"""
Represent the country as a 2-character unicode identifier.
"""
return unicode(obj)
def to_internal_value(self, data):
"""
Check that the code is a valid country code.
We leave the data in its original format so that the Django model's
CountryField can convert it to the internal representation used
by the django-countries library.
"""
if data and data not in self.COUNTRY_CODES:
raise serializers.ValidationError(
u"{code} is not a valid country code".format(code=data)
)
return data
class UserMembershipSerializer(serializers.ModelSerializer):
"""Serializes CourseTeamMemberships with only user and date_joined
......@@ -43,6 +72,7 @@ class CourseTeamSerializer(serializers.ModelSerializer):
"""Serializes a CourseTeam with membership information."""
id = serializers.CharField(source='team_id', read_only=True) # pylint: disable=invalid-name
membership = UserMembershipSerializer(many=True, read_only=True)
country = CountryField()
class Meta(object):
"""Defines meta information for the ModelSerializer."""
......@@ -66,6 +96,8 @@ class CourseTeamSerializer(serializers.ModelSerializer):
class CourseTeamCreationSerializer(serializers.ModelSerializer):
"""Deserializes a CourseTeam for creation."""
country = CountryField(required=False)
class Meta(object):
"""Defines meta information for the ModelSerializer."""
model = CourseTeam
......@@ -78,16 +110,17 @@ class CourseTeamCreationSerializer(serializers.ModelSerializer):
"language",
)
def restore_object(self, attrs, instance=None):
"""Restores a CourseTeam instance from the given attrs."""
return CourseTeam.create(
name=attrs.get("name", ''),
course_id=attrs.get("course_id"),
description=attrs.get("description", ''),
topic_id=attrs.get("topic_id", ''),
country=attrs.get("country", ''),
language=attrs.get("language", ''),
def create(self, validated_data):
team = CourseTeam.create(
name=validated_data.get("name", ''),
course_id=validated_data.get("course_id"),
description=validated_data.get("description", ''),
topic_id=validated_data.get("topic_id", ''),
country=validated_data.get("country", ''),
language=validated_data.get("language", ''),
)
team.save()
return team
class CourseTeamSerializerWithoutMembership(CourseTeamSerializer):
......@@ -134,13 +167,6 @@ class MembershipSerializer(serializers.ModelSerializer):
read_only_fields = ("date_joined", "last_activity_at")
class PaginatedMembershipSerializer(PaginationSerializer):
"""Serializes team memberships with support for pagination."""
class Meta(object):
"""Defines meta information for the PaginatedMembershipSerializer."""
object_serializer_class = MembershipSerializer
class BaseTopicSerializer(serializers.Serializer):
"""Serializes a topic without team_count."""
description = serializers.CharField()
......@@ -155,7 +181,7 @@ class TopicSerializer(BaseTopicSerializer):
model to get the count. Requires that `context` is provided with a valid course_id
in order to filter teams within the course.
"""
team_count = serializers.SerializerMethodField('get_team_count')
team_count = serializers.SerializerMethodField()
def get_team_count(self, topic):
"""Get the number of teams associated with this topic"""
......@@ -166,31 +192,25 @@ class TopicSerializer(BaseTopicSerializer):
return CourseTeam.objects.filter(course_id=self.context['course_id'], topic_id=topic['id']).count()
class PaginatedTopicSerializer(PaginationSerializer):
class BulkTeamCountTopicListSerializer(serializers.ListSerializer): # pylint: disable=abstract-method
"""
Serializes a set of topics, adding the team_count field to each topic individually, if team_count
is not already present in the topic data. Requires that `context` is provided with a valid course_id in
order to filter teams within the course.
List serializer for efficiently serializing a set of topics.
"""
class Meta(object):
"""Defines meta information for the PaginatedTopicSerializer."""
object_serializer_class = TopicSerializer
def to_representation(self, obj):
"""Adds team_count to each topic. """
data = super(BulkTeamCountTopicListSerializer, self).to_representation(obj)
add_team_count(data, self.context["course_id"])
return data
class BulkTeamCountPaginatedTopicSerializer(PaginationSerializer):
class BulkTeamCountTopicSerializer(BaseTopicSerializer): # pylint: disable=abstract-method
"""
Serializes a set of topics, adding the team_count field to each topic as a bulk operation per page
(only on the page being returned). Requires that `context` is provided with a valid course_id in
order to filter teams within the course.
Serializes a set of topics, adding the team_count field to each topic as a bulk operation.
Requires that `context` is provided with a valid course_id in order to filter teams within the course.
"""
class Meta(object):
"""Defines meta information for the BulkTeamCountPaginatedTopicSerializer."""
object_serializer_class = BaseTopicSerializer
def __init__(self, *args, **kwargs):
"""Adds team_count to each topic on the current page."""
super(BulkTeamCountPaginatedTopicSerializer, self).__init__(*args, **kwargs)
add_team_count(self.data['results'], self.context['course_id'])
class Meta: # pylint: disable=missing-docstring,old-style-class
list_serializer_class = BulkTeamCountTopicListSerializer
def add_team_count(topics, course_id):
......
......@@ -34,3 +34,10 @@ class CourseTeamMembershipFactory(DjangoModelFactory):
class Meta(object): # pylint: disable=missing-docstring
model = CourseTeamMembership
last_activity_at = LAST_ACTIVITY_AT
@classmethod
def _create(cls, model_class, *args, **kwargs):
"""Create the team membership. """
obj = model_class(*args, **kwargs)
obj.save()
return obj
......@@ -11,12 +11,9 @@ from xmodule.modulestore.tests.factories import CourseFactory
from lms.djangoapps.teams.tests.factories import CourseTeamFactory, CourseTeamMembershipFactory
from lms.djangoapps.teams.serializers import (
BaseTopicSerializer,
PaginatedTopicSerializer,
BulkTeamCountPaginatedTopicSerializer,
BulkTeamCountTopicSerializer,
TopicSerializer,
MembershipSerializer,
add_team_count
)
......@@ -73,21 +70,6 @@ class MembershipSerializerTestCase(SerializerTestCase):
self.assertNotIn('membership', data['team'])
class BaseTopicSerializerTestCase(SerializerTestCase):
"""
Tests for the `BaseTopicSerializer`, which should not serialize team count
data.
"""
def test_team_count_not_included(self):
"""Verifies that the `BaseTopicSerializer` does not include team count"""
with self.assertNumQueries(0):
serializer = BaseTopicSerializer(self.course.teams_topics[0])
self.assertEqual(
serializer.data,
{u'name': u'Tøpic', u'description': u'The bést topic!', u'id': u'0'}
)
class TopicSerializerTestCase(SerializerTestCase):
"""
Tests for the `TopicSerializer`, which should serialize team count data for
......@@ -137,7 +119,7 @@ class TopicSerializerTestCase(SerializerTestCase):
)
class BasePaginatedTopicSerializerTestCase(SerializerTestCase):
class BaseTopicSerializerTestCase(SerializerTestCase):
"""
Base class for testing the two paginated topic serializers.
"""
......@@ -191,13 +173,15 @@ class BasePaginatedTopicSerializerTestCase(SerializerTestCase):
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase):
class BulkTeamCountTopicSerializerTestCase(BaseTopicSerializerTestCase):
"""
Tests for the `BulkTeamCountPaginatedTopicSerializer`, which should serialize team_count
Tests for the `BulkTeamCountTopicSerializer`, which should serialize team_count
data for many topics with constant time SQL queries.
"""
__test__ = True
serializer = BulkTeamCountPaginatedTopicSerializer
serializer = BulkTeamCountTopicSerializer
NUM_TOPICS = 6
def test_topics_with_no_team_counts(self):
"""
......@@ -222,13 +206,13 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
one SQL query.
"""
teams_per_topic = 10
topics = self.setup_topics(num_topics=self.PAGE_SIZE + 1, teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics[:self.PAGE_SIZE], num_teams_per_topic=teams_per_topic, num_queries=1)
topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=1)
def test_scoped_within_course(self):
"""Verify that team counts are scoped within a course."""
teams_per_topic = 10
first_course_topics = self.setup_topics(num_topics=self.PAGE_SIZE, teams_per_topic=teams_per_topic)
first_course_topics = self.setup_topics(num_topics=self.NUM_TOPICS, teams_per_topic=teams_per_topic)
duplicate_topic = first_course_topics[0]
second_course = CourseFactory.create(
teams_configuration={
......@@ -239,27 +223,44 @@ class BulkTeamCountPaginatedTopicSerializerTestCase(BasePaginatedTopicSerializer
CourseTeamFactory.create(course_id=second_course.id, topic_id=duplicate_topic[u'id'])
self.assert_serializer_output(first_course_topics, num_teams_per_topic=teams_per_topic, num_queries=1)
def _merge_dicts(self, first, second):
"""Convenience method to merge two dicts in a single expression"""
result = first.copy()
result.update(second)
return result
class PaginatedTopicSerializerTestCase(BasePaginatedTopicSerializerTestCase):
"""
Tests for the `PaginatedTopicSerializer`, which will add team_count information per topic if not present.
"""
__test__ = True
serializer = PaginatedTopicSerializer
def setup_topics(self, num_topics=5, teams_per_topic=0):
"""
Helper method to set up topics on the course. Returns a list of
created topics.
"""
self.course.teams_configuration['topics'] = []
topics = [
{u'name': u'Tøpic {}'.format(i), u'description': u'The bést topic! {}'.format(i), u'id': unicode(i)}
for i in xrange(num_topics)
]
for i in xrange(num_topics):
topic_id = unicode(i)
self.course.teams_configuration['topics'].append(topics[i])
for _ in xrange(teams_per_topic):
CourseTeamFactory.create(course_id=self.course.id, topic_id=topic_id)
return topics
def test_topics_with_team_counts(self):
def assert_serializer_output(self, topics, num_teams_per_topic, num_queries):
"""
Verify that we serialize topics with team_count, making one SQL query per topic.
Verify that the serializer produced the expected topics.
"""
teams_per_topic = 2
topics = self.setup_topics(teams_per_topic=teams_per_topic)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=5)
with self.assertNumQueries(num_queries):
serializer = self.serializer(topics, context={'course_id': self.course.id}, many=True)
self.assertEqual(
serializer.data,
[self._merge_dicts(topic, {u'team_count': num_teams_per_topic}) for topic in topics]
)
def test_topics_with_team_counts_prepopulated(self):
def test_no_topics(self):
"""
Verify that if team_count is pre-populated, there are no additional SQL queries.
Verify that we return no results and make no SQL queries for a page
with no topics.
"""
teams_per_topic = 8
topics = self.setup_topics(teams_per_topic=teams_per_topic)
add_team_count(topics, self.course.id)
self.assert_serializer_output(topics, num_teams_per_topic=teams_per_topic, num_queries=0)
self.course.teams_configuration['topics'] = []
self.assert_serializer_output([], num_teams_per_topic=0, num_queries=0)
# -*- coding: utf-8 -*-
"""Tests for the teams API at the HTTP request level."""
import json
import pytz
from datetime import datetime
import pytz
from dateutil import parser
import ddt
from elasticsearch.exceptions import ConnectionError
from mock import patch
from search.search_engine_base import SearchEngine
from django.core.urlresolvers import reverse
from django.conf import settings
from django.db.models.signals import post_save
from django.utils import translation
from nose.plugins.attrib import attr
from rest_framework.test import APITestCase, APIClient
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from courseware.tests.factories import StaffFactory
from common.test.utils import skip_signal
from student.tests.factories import UserFactory, AdminFactory, CourseEnrollmentFactory
from student.models import CourseEnrollment
from util.testing import EventTestMixin
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from .factories import CourseTeamFactory, LAST_ACTIVITY_AT
from ..models import CourseTeamMembership
from ..search_indexes import CourseTeamIndexer, CourseTeam, course_team_post_save_callback
from django_comment_common.models import Role, FORUM_ROLE_COMMUNITY_TA
from django_comment_common.utils import seed_permissions_roles
......@@ -36,11 +35,23 @@ class TestDashboard(SharedModuleStoreTestCase):
"""Tests for the Teams dashboard."""
test_password = "test"
NUM_TOPICS = 10
@classmethod
def setUpClass(cls):
super(TestDashboard, cls).setUpClass()
cls.course = CourseFactory.create(
teams_configuration={"max_team_size": 10, "topics": [{"name": "foo", "id": 0, "description": "test topic"}]}
teams_configuration={
"max_team_size": 10,
"topics": [
{
"name": "Topic {}".format(topic_id),
"id": topic_id,
"description": "Description for topic {}".format(topic_id)
}
for topic_id in range(cls.NUM_TOPICS)
]
}
)
def setUp(self):
......@@ -97,6 +108,30 @@ class TestDashboard(SharedModuleStoreTestCase):
response = self.client.get(teams_url)
self.assertEqual(404, response.status_code)
def test_query_counts(self):
# Enroll in the course and log in
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.client.login(username=self.user.username, password=self.test_password)
# Check the query count on the dashboard With no teams
with self.assertNumQueries(15):
self.client.get(self.teams_url)
# Create some teams
for topic_id in range(self.NUM_TOPICS):
team = CourseTeamFactory.create(
name=u"Team for topic {}".format(topic_id),
course_id=self.course.id,
topic_id=topic_id,
)
# Add the user to the last team
team.add_user(self.user)
# Check the query count on the dashboard again
with self.assertNumQueries(19):
self.client.get(self.teams_url)
def test_bad_course_id(self):
"""
Verifies expected behavior when course_id does not reference an existing course or is invalid.
......@@ -252,6 +287,9 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
self.users[user], course.id, check_access=True
)
# Django Rest Framework v3 requires us to pass a request to serializers
# that have URL fields. Since we're invoking this code outside the context
# of a request, we need to simulate that there's a request.
self.solar_team.add_user(self.users['student_enrolled'])
self.nuclear_team.add_user(self.users['student_enrolled_both_courses_other_team'])
self.another_team.add_user(self.users['student_enrolled_both_courses_other_team'])
......@@ -311,7 +349,17 @@ class TeamAPITestCase(APITestCase, SharedModuleStoreTestCase):
response = func(url, data=data, content_type=content_type)
else:
response = func(url, data=data)
self.assertEqual(expected_status, response.status_code)
self.assertEqual(
expected_status,
response.status_code,
msg="Expected status {expected} but got {actual}: {content}".format(
expected=expected_status,
actual=response.status_code,
content=response.content,
)
)
if expected_status == 200:
return json.loads(response.content)
else:
......
......@@ -1866,6 +1866,12 @@ INSTALLED_APPS = (
'provider.oauth2',
'oauth2_provider',
# We don't use this directly (since we use OAuth2), but we need to install it anyway.
# When a user is deleted, Django queries all tables with a FK to the auth_user table,
# and since django-rest-framework-oauth imports this, it will try to access tables
# defined by oauth_provider. If those tables don't exist, an error can occur.
'oauth_provider',
'auth_exchange',
# For the wiki
......@@ -1981,6 +1987,14 @@ INSTALLED_APPS = (
CSRF_COOKIE_AGE = 60 * 60 * 24 * 7 * 52
######################### Django Rest Framework ########################
REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'openedx.core.lib.api.paginators.DefaultPagination',
'PAGE_SIZE': 10,
}
######################### MARKETING SITE ###############################
EDXMKTG_LOGGED_IN_COOKIE_NAME = 'edxloggedin'
EDXMKTG_USER_INFO_COOKIE_NAME = 'edx-user-info'
......
......@@ -9,16 +9,12 @@ from django.conf import settings
# Force settings to run so that the python path is modified
settings.INSTALLED_APPS # pylint: disable=pointless-statement
from instructor.services import InstructorService
from openedx.core.lib.django_startup import autostartup
import edxmako
import logging
from monkey_patch import django_utils_translation
import analytics
from edx_proctoring.runtime import set_runtime_service
from openedx.core.djangoapps.credit.services import CreditService
log = logging.getLogger(__name__)
......@@ -51,6 +47,11 @@ def run():
# right now edx_proctoring is dependent on the openedx.core.djangoapps.credit
# as well as the instructor dashboard (for deleting student attempts)
if settings.FEATURES.get('ENABLE_PROCTORED_EXAMS'):
# Import these here to avoid circular dependencies of the form:
# edx-platform app --> DRF --> django translation --> edx-platform app
from edx_proctoring.runtime import set_runtime_service
from instructor.services import InstructorService
from openedx.core.djangoapps.credit.services import CreditService
set_runtime_service('credit', CreditService())
set_runtime_service('instructor', InstructorService())
......
......@@ -124,4 +124,4 @@ def course_grading_policy(course_key):
final grade.
"""
course = _retrieve_course(course_key)
return GradingPolicySerializer(course.raw_grader).data
return GradingPolicySerializer(course.raw_grader, many=True).data
"""
API Serializers
"""
from collections import defaultdict
from rest_framework import serializers
......@@ -11,23 +13,58 @@ class GradingPolicySerializer(serializers.Serializer):
dropped = serializers.IntegerField(source='drop_count')
weight = serializers.FloatField()
def to_representation(self, obj):
"""
Return a representation of the grading policy.
"""
# Backwards compatibility with the behavior of DRF v2.
# When the grader dictionary was missing keys, DRF v2 would default to None;
# DRF v3 unhelpfully raises an exception.
return dict(
super(GradingPolicySerializer, self).to_representation(
defaultdict(lambda: None, obj)
)
)
# pylint: disable=invalid-name
class BlockSerializer(serializers.Serializer):
""" Serializer for course structure block. """
id = serializers.CharField(source='usage_key')
type = serializers.CharField(source='block_type')
parent = serializers.CharField(source='parent')
parent = serializers.CharField(required=False)
display_name = serializers.CharField()
graded = serializers.BooleanField(default=False)
format = serializers.CharField()
children = serializers.CharField()
def to_representation(self, obj):
"""
Return a representation of the block.
NOTE: this method maintains backwards compatibility with the behavior
of Django Rest Framework v2.
"""
data = super(BlockSerializer, self).to_representation(obj)
# Backwards compatibility with the behavior of DRF v2
# Include a NULL value for "parent" in the representation
# (instead of excluding the key entirely)
if obj.get("parent") is None:
data["parent"] = None
# Backwards compatibility with the behavior of DRF v2
# Leave the children list as a list instead of serializing
# it to a string.
data["children"] = obj["children"]
return data
class CourseStructureSerializer(serializers.Serializer):
""" Serializer for course structure. """
root = serializers.CharField(source='root')
blocks = serializers.SerializerMethodField('get_blocks')
root = serializers.CharField()
blocks = serializers.SerializerMethodField()
def get_blocks(self, structure):
""" Serialize the individual blocks. """
......
......@@ -2,12 +2,33 @@
from rest_framework import serializers
from opaque_keys.edx.keys import CourseKey
from opaque_keys import InvalidKeyError
from openedx.core.djangoapps.credit.models import CreditCourse
class CourseKeyField(serializers.Field):
"""
Serializer field for a model CourseKey field.
"""
def to_representation(self, data):
"""Convert a course key to unicode. """
return unicode(data)
def to_internal_value(self, data):
"""Convert unicode to a course key. """
try:
return CourseKey.from_string(data)
except InvalidKeyError as ex:
raise serializers.ValidationError("Invalid course key: {msg}".format(msg=ex.msg))
class CreditCourseSerializer(serializers.ModelSerializer):
""" CreditCourse Serializer """
course_key = CourseKeyField()
class Meta(object): # pylint: disable=missing-docstring
model = CreditCourse
exclude = ('id',)
......@@ -393,10 +393,7 @@ class CreditCourseViewSetTests(TestCase):
# POSTs without a CSRF token should fail.
response = client.post(self.path, data=json.dumps(data), content_type=JSON)
# NOTE (CCB): Ordinarily we would expect a 403; however, since the CSRF validation and session authentication
# fail, DRF considers the request to be unauthenticated.
self.assertEqual(response.status_code, 401)
self.assertEqual(response.status_code, 403)
self.assertIn('CSRF', response.content)
# Retrieve a CSRF token
......
......@@ -18,8 +18,9 @@ from django.views.decorators.http import require_POST, require_GET
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
import pytz
from rest_framework import viewsets, mixins, permissions, authentication
from rest_framework import viewsets, mixins, permissions
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from util.json_request import JsonResponse
from util.date_utils import from_timestamp
from openedx.core.djangoapps.credit import api
......@@ -377,17 +378,28 @@ class CreditCourseViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, view
lookup_value_regex = settings.COURSE_KEY_REGEX
queryset = CreditCourse.objects.all()
serializer_class = CreditCourseSerializer
authentication_classes = (authentication.OAuth2Authentication, authentication.SessionAuthentication,)
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
permission_classes = (permissions.IsAuthenticated, permissions.IsAdminUser)
# In Django Rest Framework v3, there is a default pagination
# class that transmutes the response data into a dictionary
# with pagination information. The original response data (a list)
# is stored in a "results" value of the dictionary.
# For backwards compatibility with the existing API, we disable
# the default behavior by setting the pagination_class to None.
pagination_class = None
# This CSRF exemption only applies when authenticating without SessionAuthentication.
# SessionAuthentication will enforce CSRF protection.
@method_decorator(csrf_exempt)
def dispatch(self, request, *args, **kwargs):
# Convert the course ID/key from a string to an actual CourseKey object.
course_id = kwargs.get(self.lookup_field, None)
return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs)
if course_id:
kwargs[self.lookup_field] = CourseKey.from_string(course_id)
def get_object(self):
# Convert the serialized course key into a CourseKey instance
# so we can look up the object.
course_key = self.kwargs.get(self.lookup_field)
if course_key is not None:
self.kwargs[self.lookup_field] = CourseKey.from_string(course_key)
return super(CreditCourseViewSet, self).dispatch(request, *args, **kwargs)
return super(CreditCourseViewSet, self).get_object()
......@@ -30,6 +30,40 @@ TEST_UPLOAD_DT = datetime.datetime(2002, 1, 9, 15, 43, 01, tzinfo=UTC)
TEST_UPLOAD_DT2 = datetime.datetime(2003, 1, 9, 15, 43, 01, tzinfo=UTC)
class PatchedClient(APIClient):
"""
Patch DRF's APIClient to avoid a unicode error on file upload.
Famous last words: This is a *temporary* fix that we should be
able to remove once we upgrade Django past 1.4.
"""
def request(self, *args, **kwargs):
"""Construct an API request. """
# DRF's default test client implementation uses `six.text_type()`
# to convert the CONTENT_TYPE to `unicode`. In Django 1.4,
# this causes a `UnicodeDecodeError` when Django parses a multipart
# upload.
#
# This is the DRF code we're working around:
# https://github.com/tomchristie/django-rest-framework/blob/3.1.3/rest_framework/compat.py#L227
#
# ... and this is the Django code that raises the exception:
#
# https://github.com/django/django/blob/1.4.22/django/http/multipartparser.py#L435
#
# Django unhelpfully swallows the exception, so to the application code
# it appears as though the user didn't send any file data.
#
# This appears to be an issue only with requests constructed in the test
# suite, not with the upload code used in production.
#
if isinstance(kwargs.get("CONTENT_TYPE"), basestring):
kwargs["CONTENT_TYPE"] = str(kwargs["CONTENT_TYPE"])
return super(PatchedClient, self).request(*args, **kwargs)
class ProfileImageEndpointTestCase(UserSettingsEventTestMixin, APITestCase):
"""
Base class / shared infrastructure for tests of profile_image "upload" and
......@@ -111,6 +145,10 @@ class ProfileImageUploadTestCase(ProfileImageEndpointTestCase):
"""
_view_name = "profile_image_upload"
# Use the patched version of the API client to workaround a unicode issue
# with DRF 3.1 and Django 1.4. Remove this after we upgrade Django past 1.4!
client_class = PatchedClient
def check_upload_event_emitted(self, old=None, new=TEST_UPLOAD_DT):
"""
Make sure we emit a UserProfile event corresponding to the
......
......@@ -183,7 +183,7 @@ def update_account_settings(requesting_user, update, username=None):
serializer.save()
if "language_proficiencies" in update:
new_language_proficiencies = legacy_profile_serializer.data["language_proficiencies"]
new_language_proficiencies = update["language_proficiencies"]
emit_setting_changed_event(
user=existing_user,
db_table=existing_user_profile.language_proficiencies.model._meta.db_table,
......
......@@ -53,7 +53,7 @@ class UserReadOnlySerializer(serializers.Serializer):
super(UserReadOnlySerializer, self).__init__(*args, **kwargs)
def to_native(self, user):
def to_representation(self, user):
"""
Overwrite to_native to handle custom logic since we are serializing two models as one here
:param user: User object
......@@ -152,8 +152,8 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
Class that serializes the portion of UserProfile model needed for account information.
"""
profile_image = serializers.SerializerMethodField("_get_profile_image")
requires_parental_consent = serializers.SerializerMethodField("get_requires_parental_consent")
language_proficiencies = LanguageProficiencySerializer(many=True, allow_add_remove=True, required=False)
requires_parental_consent = serializers.SerializerMethodField()
language_proficiencies = LanguageProficiencySerializer(many=True, required=False)
class Meta(object): # pylint: disable=missing-docstring
model = UserProfile
......@@ -165,25 +165,21 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
read_only_fields = ()
explicit_read_only_fields = ("profile_image", "requires_parental_consent")
def validate_name(self, attrs, source):
def validate_name(self, new_name):
""" Enforce minimum length for name. """
if source in attrs:
new_name = attrs[source].strip()
if len(new_name) < NAME_MIN_LENGTH:
raise serializers.ValidationError(
"The name field must be at least {} characters long.".format(NAME_MIN_LENGTH)
)
attrs[source] = new_name
if len(new_name) < NAME_MIN_LENGTH:
raise serializers.ValidationError(
"The name field must be at least {} characters long.".format(NAME_MIN_LENGTH)
)
return new_name
return attrs
def validate_language_proficiencies(self, attrs, source):
def validate_language_proficiencies(self, value):
""" Enforce all languages are unique. """
language_proficiencies = [language for language in attrs.get(source, [])]
unique_language_proficiencies = set(language.code for language in language_proficiencies)
language_proficiencies = [language for language in value]
unique_language_proficiencies = set(language["code"] for language in language_proficiencies)
if len(language_proficiencies) != len(unique_language_proficiencies):
raise serializers.ValidationError("The language_proficiencies field must consist of unique languages")
return attrs
return value
def transform_gender(self, user_profile, value):
""" Converts empty string to None, to indicate not set. Replaced by to_representation in version 3. """
......@@ -230,3 +226,29 @@ class AccountLegacyProfileSerializer(serializers.HyperlinkedModelSerializer, Rea
call the method with a single argument, the user_profile object.
"""
return AccountLegacyProfileSerializer.get_profile_image(user_profile, user_profile.user)
def update(self, instance, validated_data):
"""
Update the profile, including nested fields.
"""
language_proficiencies = validated_data.pop("language_proficiencies", None)
# Update all fields on the user profile that are writeable,
# except for "language_proficiencies", which we'll update separately
update_fields = set(self.get_writeable_fields()) - set(["language_proficiencies"])
for field_name in update_fields:
default = getattr(instance, field_name)
field_value = validated_data.get(field_name, default)
setattr(instance, field_name, field_value)
instance.save()
# Now update the related language proficiency
if language_proficiencies is not None:
instance.language_proficiencies.all().delete()
instance.language_proficiencies.bulk_create([
LanguageProficiency(user_profile=instance, code=language["code"])
for language in language_proficiencies
])
return instance
......@@ -164,7 +164,10 @@ class TestAccountApi(UserSettingsEventTestMixin, TestCase):
field_errors = context_manager.exception.field_errors
self.assertEqual(3, len(field_errors))
self.assertEqual("This field is not editable via this API", field_errors["username"]["developer_message"])
self.assertIn("Select a valid choice", field_errors["gender"]["developer_message"])
self.assertIn(
"Value \'undecided\' is not valid for field \'gender\'",
field_errors["gender"]["developer_message"]
)
self.assertIn("Valid e-mail address required.", field_errors["email"]["developer_message"])
@patch('django.core.mail.send_mail')
......
......@@ -359,16 +359,19 @@ class TestAccountAPI(UserAPITestCase):
self.assertEqual(404, response.status_code)
@ddt.data(
("gender", "f", "not a gender", u"Select a valid choice. not a gender is not one of the available choices."),
("level_of_education", "none", u"ȻħȺɍłɇs", u"Select a valid choice. ȻħȺɍłɇs is not one of the available choices."),
("country", "GB", "XY", u"Select a valid choice. XY is not one of the available choices."),
("year_of_birth", 2009, "not_an_int", u"Enter a whole number."),
("name", "bob", "z" * 256, u"Ensure this value has at most 255 characters (it has 256)."),
("gender", "f", "not a gender", u'"not a gender" is not a valid choice.'),
("level_of_education", "none", u"ȻħȺɍłɇs", u'"ȻħȺɍłɇs" is not a valid choice.'),
("country", "GB", "XY", u'"XY" is not a valid choice.'),
("year_of_birth", 2009, "not_an_int", u"A valid integer is required."),
("name", "bob", "z" * 256, u"Ensure this field has no more than 255 characters."),
("name", u"ȻħȺɍłɇs", "z ", u"The name field must be at least 2 characters long."),
("goals", "Smell the roses"),
("mailing_address", "Sesame Street"),
# Note that we store the raw data, so it is up to client to escape the HTML.
("bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>", "z" * 3001, u"Ensure this value has at most 3000 characters (it has 3001)."),
(
"bio", u"<html>Lacrosse-playing superhero 壓是進界推日不復女</html>",
"z" * 3001, u"Ensure this field has no more than 3000 characters."
),
# Note that email is tested below, as it is not immediately updated.
# Note that language_proficiencies is tested below as there are multiple error and success conditions.
)
......@@ -568,10 +571,10 @@ class TestAccountAPI(UserAPITestCase):
self.assertItemsEqual(response.data["language_proficiencies"], proficiencies)
@ddt.data(
(u"not_a_list", [{u'non_field_errors': [u'Expected a list of items.']}]),
([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data']}]),
(u"not_a_list", {u'non_field_errors': [u'Expected a list of items but got type "unicode".']}),
([u"not_a_JSON_object"], [{u'non_field_errors': [u'Invalid data. Expected a dictionary, but got unicode.']}]),
([{}], [{"code": [u"This field is required."]}]),
([{u"code": u"invalid_language_code"}], [{'code': [u'Select a valid choice. invalid_language_code is not one of the available choices.']}]),
([{u"code": u"invalid_language_code"}], [{'code': [u'"invalid_language_code" is not a valid choice.']}]),
([{u"code": u"kw"}, {u"code": u"el"}, {u"code": u"kw"}], [u'The language_proficiencies field must consist of unique languages']),
)
@ddt.unpack
......
......@@ -9,9 +9,10 @@ from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.db import IntegrityError
from django.utils.translation import ugettext as _
from student.models import User, UserProfile
from django.utils.translation import ugettext_noop
from student.models import User, UserProfile
from request_cache import get_request_or_stub
from ..errors import (
UserAPIInternalError, UserAPIRequestError, UserNotFound, UserNotAuthorized,
PreferenceValidationError, PreferenceUpdateError
......@@ -68,7 +69,17 @@ def get_user_preferences(requesting_user, username=None):
UserAPIInternalError: the operation failed due to an unexpected error.
"""
existing_user = _get_user(requesting_user, username, allow_staff=True)
user_serializer = UserSerializer(existing_user)
# Django Rest Framework V3 uses the current request to version
# hyperlinked URLS, so we need to retrieve the request and pass
# it in the serializer's context (otherwise we get an AssertionError).
# We're retrieving the request from the cache rather than passing it in
# as an argument because this is an implementation detail of how we're
# serializing data, which we want to encapsulate in the API call.
context = {
"request": get_request_or_stub()
}
user_serializer = UserSerializer(existing_user, context=context)
return user_serializer.data["preferences"]
......@@ -356,7 +367,7 @@ def validate_user_preference_serializer(serializer, preference_key, preference_v
developer_message = u"Value '{preference_value}' not valid for preference '{preference_key}': {error}".format(
preference_key=preference_key, preference_value=preference_value, error=serializer.errors
)
if serializer.errors["key"]:
if "key" in serializer.errors:
user_message = _(u"Invalid user preference key '{preference_key}'.").format(
preference_key=preference_key
)
......
......@@ -403,7 +403,7 @@ def get_expected_validation_developer_message(preference_key, preference_value):
preference_key=preference_key,
preference_value=preference_value,
error={
"key": [u"Ensure this value has at most 255 characters (it has 256)."]
"key": [u"Ensure this field has no more than 255 characters."]
}
)
......
......@@ -6,8 +6,8 @@ from .models import UserPreference
class UserSerializer(serializers.HyperlinkedModelSerializer):
name = serializers.SerializerMethodField("get_name")
preferences = serializers.SerializerMethodField("get_preferences")
name = serializers.SerializerMethodField()
preferences = serializers.SerializerMethodField()
def get_name(self, user):
profile = UserProfile.objects.get(user=user)
......@@ -32,9 +32,10 @@ class UserPreferenceSerializer(serializers.HyperlinkedModelSerializer):
class RawUserPreferenceSerializer(serializers.ModelSerializer):
"""Serializer that generates a raw representation of a user preference.
"""
user = serializers.PrimaryKeyRelatedField()
Serializer that generates a raw representation of a user preference.
"""
user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all())
class Meta(object): # pylint: disable=missing-docstring
model = UserPreference
......@@ -57,3 +58,11 @@ class ReadOnlyFieldsSerializerMixin(object):
cls.Meta.read_only_fields tuple.
"""
return getattr(cls.Meta, 'read_only_fields', '') + getattr(cls.Meta, 'explicit_read_only_fields', '')
@classmethod
def get_writeable_fields(cls):
"""
Return all fields on this serializer that are writeable.
"""
all_fields = getattr(cls.Meta, 'fields', tuple())
return tuple(set(all_fields) - set(cls.get_read_only_fields()))
""" Common Authentication Handlers used across projects. """
from rest_framework import authentication
from rest_framework.authentication import SessionAuthentication
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.compat import oauth2_provider, provider_now
from rest_framework_oauth.compat import oauth2_provider, provider_now
class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthentication):
class SessionAuthenticationAllowInactiveUser(SessionAuthentication):
"""Ensure that the user is logged in, but do not require the account to be active.
We use this in the special case that a user has created an account,
......@@ -51,7 +52,7 @@ class SessionAuthenticationAllowInactiveUser(authentication.SessionAuthenticatio
return (user, None)
class OAuth2AuthenticationAllowInactiveUser(authentication.OAuth2Authentication):
class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication):
"""
This is a temporary workaround while the is_active field on the user is coupled
with whether or not the user has verified ownership of their claimed email address.
......
"""Fields useful for edX API implementations."""
from django.core.exceptions import ValidationError
from rest_framework.serializers import CharField, Field
from rest_framework.serializers import Field
class ExpandableField(Field):
......@@ -18,25 +16,19 @@ class ExpandableField(Field):
self.expanded = kwargs.pop('expanded_serializer')
super(ExpandableField, self).__init__(**kwargs)
def field_to_native(self, obj, field_name):
"""Converts obj to a native representation, using the expanded serializer if the context requires it."""
if 'expand' in self.context and field_name in self.context['expand']:
self.expanded.initialize(self, field_name)
return self.expanded.field_to_native(obj, field_name)
else:
self.collapsed.initialize(self, field_name)
return self.collapsed.field_to_native(obj, field_name)
def to_representation(self, obj):
"""
Return a representation of the field that is either expanded or collapsed.
"""
should_expand = self.field_name in self.context.get("expand", [])
field = self.expanded if should_expand else self.collapsed
# Avoid double-binding the field, otherwise we'll get
# an error about the source kwarg being redundant.
if field.source is None:
field.bind(self.field_name, self)
class NonEmptyCharField(CharField):
"""
A field that enforces non-emptiness even for partial updates.
if should_expand:
self.expanded.context["expand"] = set(field.context.get("expand", []))
This is necessary because prior to version 3, DRF skips validation for empty
values. Thus, CharField's min_length and RegexField cannot be used to
enforce this constraint.
"""
def validate(self, value):
super(NonEmptyCharField, self).validate(value)
if not value.strip():
raise ValidationError(self.error_messages["required"])
return field.to_representation(obj)
"""
Django Rest Framework view mixins.
"""
from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.mixins import CreateModelMixin
from rest_framework.response import Response
class PutAsCreateMixin(CreateModelMixin):
"""
Backwards compatibility with Django Rest Framework v2, which allowed
creation of a new resource using PUT.
"""
def update(self, request, *args, **kwargs):
"""
Create/update course modes for a course.
"""
# First, try to update the existing instance
try:
try:
return super(PutAsCreateMixin, self).update(request, *args, **kwargs)
except Http404:
# If no instance exists yet, create it.
# This is backwards-compatible with the behavior of DRF v2.
return super(PutAsCreateMixin, self).create(request, *args, **kwargs)
# Backwards compatibility with DRF v2 behavior, which would catch model-level
# validation errors and return a 400
except ValidationError as err:
return Response(err.messages, status=status.HTTP_400_BAD_REQUEST)
......@@ -3,6 +3,31 @@
from django.http import Http404
from django.core.paginator import Paginator, InvalidPage
from rest_framework.response import Response
from rest_framework import pagination
class DefaultPagination(pagination.PageNumberPagination):
"""
Default paginator for APIs in edx-platform.
This is configured in settings to be automatically used
by any subclass of Django Rest Framework's generic API views.
"""
page_size_query_param = "page_size"
def get_paginated_response(self, data):
"""
Annotate the response with pagination information.
"""
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'count': self.page.paginator.count,
'num_pages': self.page.paginator.num_pages,
'results': data
})
def paginate_search_results(object_class, search_results, page_size, page):
"""
......
from rest_framework import pagination, serializers
"""
Serializers to be used in APIs.
"""
class PaginationSerializer(pagination.PaginationSerializer):
"""
Custom PaginationSerializer for openedx.
Adds the following fields:
- num_pages: total number of pages
- current_page: the current page being returned
- start: the index of the first page item within the overall collection
"""
start_page = 1 # django Paginator.page objects have 1-based indexes
num_pages = serializers.Field(source='paginator.num_pages')
current_page = serializers.SerializerMethodField('get_current_page')
start = serializers.SerializerMethodField('get_start')
sort_order = serializers.SerializerMethodField('get_sort_order')
def get_current_page(self, page):
"""Get the current page"""
return page.number
def get_start(self, page):
"""Get the index of the first page item within the overall collection"""
return (self.get_current_page(page) - self.start_page) * page.paginator.per_page
def get_sort_order(self, page): # pylint: disable=unused-argument
"""Get the order by which this collection was sorted"""
return self.context.get('sort_order')
from rest_framework import serializers
class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
......@@ -54,9 +30,10 @@ class CollapsedReferenceSerializer(serializers.HyperlinkedModelSerializer):
super(CollapsedReferenceSerializer, self).__init__(*args, **kwargs)
self.fields[id_source] = serializers.CharField(read_only=True, source=id_source)
self.fields[id_source] = serializers.CharField(read_only=True)
self.fields['url'].view_name = view_name
self.fields['url'].lookup_field = lookup_field
self.fields['url'].lookup_url_kwarg = lookup_field
class Meta(object):
"""Defines meta information for the ModelSerializer.
......
......@@ -9,6 +9,7 @@ from django.utils.translation import ugettext as _
from rest_framework import status, response
from rest_framework.exceptions import APIException
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import clone_request
from rest_framework.response import Response
from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin
from rest_framework.generics import GenericAPIView
......@@ -193,3 +194,23 @@ class RetrievePatchAPIView(RetrieveModelMixin, UpdateModelMixin, GenericAPIView)
add_serializer_errors(serializer, patch, field_errors)
return field_errors
def get_object_or_none(self):
"""
Retrieve an object or return None if the object can't be found.
NOTE: This replaces functionality that was removed in Django Rest Framework v3.1.
"""
try:
return self.get_object()
except Http404:
if self.request.method == 'PUT':
# For PUT-as-create operation, we need to ensure that we have
# relevant permissions, as if this was a POST request. This
# will either raise a PermissionDenied exception, or simply
# return None.
self.check_permissions(clone_request(self.request, 'POST'))
else:
# PATCH requests where the object does not exist should still
# return a 404 response.
raise
......@@ -28,7 +28,7 @@ django-ses==0.7.0
django-simple-history==1.6.3
django-storages==1.1.5
django-method-override==0.1.0
djangorestframework==2.3.14
djangorestframework>=3.1,<3.2
django==1.4.22
elasticsearch==0.4.5
facebook-sdk==0.4.0
......
......@@ -12,6 +12,7 @@ git+https://github.com/edx/django-staticfiles.git@031bdeaea85798b8c284e2a09977df
-e git+https://github.com/edx/django-pipeline.git@88ec8a011e481918fdc9d2682d4017c835acd8be#egg=django-pipeline
-e git+https://github.com/edx/django-wiki.git@cd0b2b31997afccde519fe5b3365e61a9edb143f#egg=django-wiki
-e git+https://github.com/edx/django-oauth2-provider.git@0.2.7-fork-edx-5#egg=django-oauth2-provider
-e git+https://github.com/edx/django-rest-framework-oauth.git@f0b503fda8c254a38f97fef802ded4f5fe367f7a#egg=djangorestframework-oauth
-e git+https://github.com/edx/MongoDBProxy.git@25b99097615bda06bd7cdfe5669ed80dc2a7fed0#egg=mongodb_proxy
git+https://github.com/edx/nltk.git@2.0.6#egg=nltk==2.0.6
-e git+https://github.com/dementrock/pystache_custom.git@776973740bdaad83a3b029f96e415a7d1e8bec2f#egg=pystache_custom-dev
......@@ -40,13 +41,13 @@ git+https://github.com/edx/rfc6266.git@v0.0.5-edx#egg=rfc6266==0.0.5-edx
-e git+https://github.com/edx/event-tracking.git@0.2.0#egg=event-tracking
-e git+https://github.com/edx-solutions/django-splash.git@7579d052afcf474ece1239153cffe1c89935bc4f#egg=django-splash
-e git+https://github.com/edx/acid-block.git@e46f9cda8a03e121a00c7e347084d142d22ebfb7#egg=acid-xblock
-e git+https://github.com/edx/edx-ora2.git@release-2015-08-25T16.16#egg=edx-ora2
-e git+https://github.com/edx/edx-submissions.git@9538ee8a971d04dc1cb05e88f6aa0c36b224455c#egg=edx-submissions
-e git+https://github.com/edx/edx-ora2.git@release-2015-09-16T15.28#egg=edx-ora2
-e git+https://github.com/edx/edx-submissions.git@0.1.0#egg=edx-submissions
-e git+https://github.com/edx/opaque-keys.git@27dc382ea587483b1e3889a3d19cbd90b9023a06#egg=opaque-keys
git+https://github.com/edx/ease.git@release-2015-07-14#egg=ease==0.1.3
git+https://github.com/edx/i18n-tools.git@v0.1.3#egg=i18n-tools==v0.1.3
git+https://github.com/edx/edx-oauth2-provider.git@0.5.7#egg=oauth2-provider==0.5.7
-e git+https://github.com/edx/edx-val.git@v0.0.5#egg=edx-val
-e git+https://github.com/edx/edx-val.git@0.0.6#egg=edx-val
-e git+https://github.com/pmitros/RecommenderXBlock.git@518234bc354edbfc2651b9e534ddb54f96080779#egg=recommender-xblock
-e git+https://github.com/edx/edx-search.git@release-2015-09-11a#egg=edx-search
-e git+https://github.com/edx/edx-milestones.git@9b44a37edc3d63a23823c21a63cdd53ef47a7aa4#egg=edx-milestones
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment