Commit 871a50b2 by Clinton Blackburn

Merge pull request #105 from edx/clintonb/catalog-create-fix

Updated catalog viewer creation logic
parents 248ec350 9794bb60
from urllib.parse import urlencode from urllib.parse import urlencode
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
...@@ -96,27 +95,10 @@ class CatalogSerializer(serializers.ModelSerializer): ...@@ -96,27 +95,10 @@ class CatalogSerializer(serializers.ModelSerializer):
allow_null=True, allow_empty=True, required=False, allow_null=True, allow_empty=True, required=False,
help_text=_('Usernames of users with explicit access to view this catalog')) help_text=_('Usernames of users with explicit access to view this catalog'))
def is_valid(self, **kwargs):
# Ensure that the catalog's viewers actually exist in the
# DB. We keep this in a transaction so that users are only
# created if the data is valid.
sid = transaction.savepoint()
for username in self.initial_data.get('viewers', ()): # pylint: disable=no-member
User.objects.get_or_create(username=username)
if super().is_valid(**kwargs):
# Data is good; commit the transaction.
transaction.savepoint_commit(sid)
return True
else:
# Invalid data; roll back the user creation.
transaction.savepoint_rollback(sid)
return False
def create(self, validated_data): def create(self, validated_data):
viewers = set() viewers = validated_data.pop('viewers')
for username in validated_data.pop('viewers'): viewers = User.objects.filter(username__in=viewers)
user = User.objects.get(username=username)
viewers.add(user)
# Set viewers after the model has been saved # Set viewers after the model has been saved
instance = super(CatalogSerializer, self).create(validated_data) instance = super(CatalogSerializer, self).create(validated_data)
instance.viewers = viewers instance.viewers = viewers
......
...@@ -39,21 +39,6 @@ class CatalogSerializerTests(TestCase): ...@@ -39,21 +39,6 @@ class CatalogSerializerTests(TestCase):
} }
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_create_new_user(self):
username = 'test-user'
data = {
'viewers': [username],
'id': None,
'name': 'test new catalog',
'query': '*',
}
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
serializer = CatalogSerializer(data=data)
self.assertTrue(serializer.is_valid())
catalog = serializer.save()
self.assertEqual([viewer.username for viewer in catalog.viewers], [username])
self.assertEqual(User.objects.filter(username=username).count(), 1) # pylint: disable=no-member
def test_invalid_data_user_create(self): def test_invalid_data_user_create(self):
"""Verify that users are not created if the serializer data is invalid.""" """Verify that users are not created if the serializer data is invalid."""
username = 'test-user' username = 'test-user'
......
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
import urllib import urllib
import ddt import ddt
from django.contrib.auth import get_user_model
import pytz import pytz
import responses import responses
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
...@@ -16,6 +17,8 @@ from course_discovery.apps.core.tests.factories import UserFactory ...@@ -16,6 +17,8 @@ from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
User = get_user_model()
@ddt.ddt @ddt.ddt
class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixin, APITestCase): class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixin, APITestCase):
...@@ -96,6 +99,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -96,6 +99,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
self.mock_user_info_response(self.user) self.mock_user_info_response(self.user)
self.assert_catalog_created(HTTP_AUTHORIZATION=self.generate_oauth2_token_header(self.user)) self.assert_catalog_created(HTTP_AUTHORIZATION=self.generate_oauth2_token_header(self.user))
def test_create_with_new_user(self):
""" Verify that new users are created if the list of viewers includes the usernames of non-existent users. """
new_viewer_username = 'new-guy'
existing_viewer = UserFactory()
viewers = [new_viewer_username, existing_viewer.username]
data = {
'name': 'Test Catalog',
'query': '*:*',
'viewers': ','.join(viewers)
}
# NOTE: We explicitly avoid using the JSON data type so that we properly test string parsing.
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 201)
catalog = Catalog.objects.latest()
latest_user = User.objects.latest()
self.assertEqual(latest_user.username, new_viewer_username)
self.assertListEqual(list(catalog.viewers), [existing_viewer, latest_user])
def test_create_with_new_user_error(self):
""" Verify no users are created if an error occurs while processing a create request. """
# The missing name and query fields should trigger an error
data = {
'viewers': ['new-guy']
}
original_user_count = User.objects.count()
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(User.objects.count(), original_user_count)
def test_courses(self): def test_courses(self):
""" Verify the endpoint returns the list of courses contained in the catalog. """ """ Verify the endpoint returns the list of courses contained in the catalog. """
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
......
...@@ -2,7 +2,9 @@ import logging ...@@ -2,7 +2,9 @@ import logging
import os import os
from io import StringIO from io import StringIO
from django.contrib.auth import get_user_model
from django.core.management import call_command from django.core.management import call_command
from django.db import transaction
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from dry_rest_permissions.generics import DRYPermissions from dry_rest_permissions.generics import DRYPermissions
...@@ -14,17 +16,18 @@ from rest_framework.permissions import IsAuthenticated ...@@ -14,17 +16,18 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from course_discovery.apps.api.filters import PermissionsFilter from course_discovery.apps.api.filters import PermissionsFilter
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer, CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer,
CourseSerializerExcludingClosedRuns, AffiliateWindowSerializer, ContainedCourseRunsSerializer CourseSerializerExcludingClosedRuns, AffiliateWindowSerializer, ContainedCourseRunsSerializer
) )
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model()
# pylint: disable=no-member # pylint: disable=no-member
...@@ -37,10 +40,26 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -37,10 +40,26 @@ class CatalogViewSet(viewsets.ModelViewSet):
queryset = Catalog.objects.all() queryset = Catalog.objects.all()
serializer_class = CatalogSerializer serializer_class = CatalogSerializer
# The boilerplate methods are required to be recognized by swagger @transaction.atomic
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
""" Create a new catalog. """ """ Create a new catalog. """
return super(CatalogViewSet, self).create(request, *args, **kwargs) data = request.data.copy()
usernames = request.data.get('viewers', ())
# Add support for parsing a comma-separated list from Swagger
if isinstance(usernames, str):
usernames = usernames.split(',')
data.setlist('viewers', usernames)
# Ensure the users exist
for username in usernames:
User.objects.get_or_create(username=username)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
""" Destroy a catalog. """ """ Destroy a catalog. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment