Commit 2d308c9b by Clinton Blackburn Committed by Peter Fogg

Updated catalog viewer creation logic

The logic for creating viewers has been moved to the view. This allows us to correct the data passed to the serializer when the viewers field is set to a comma-delimited list (e.g. for Swagger).

 ECOM-4489
parent 37df16bc
from urllib.parse import urlencode
from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
......@@ -96,27 +95,10 @@ class CatalogSerializer(serializers.ModelSerializer):
allow_null=True, allow_empty=True, required=False,
help_text=_('Usernames of users with explicit access to view this catalog'))
def is_valid(self, **kwargs):
# Ensure that the catalog's viewers actually exist in the
# DB. We keep this in a transaction so that users are only
# created if the data is valid.
sid = transaction.savepoint()
for username in self.initial_data.get('viewers', ()): # pylint: disable=no-member
User.objects.get_or_create(username=username)
if super().is_valid(**kwargs):
# Data is good; commit the transaction.
transaction.savepoint_commit(sid)
return True
else:
# Invalid data; roll back the user creation.
transaction.savepoint_rollback(sid)
return False
def create(self, validated_data):
viewers = set()
for username in validated_data.pop('viewers'):
user = User.objects.get(username=username)
viewers.add(user)
viewers = validated_data.pop('viewers')
viewers = User.objects.filter(username__in=viewers)
# Set viewers after the model has been saved
instance = super(CatalogSerializer, self).create(validated_data)
instance.viewers = viewers
......
......@@ -39,21 +39,6 @@ class CatalogSerializerTests(TestCase):
}
self.assertDictEqual(serializer.data, expected)
def test_create_new_user(self):
username = 'test-user'
data = {
'viewers': [username],
'id': None,
'name': 'test new catalog',
'query': '*',
}
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
serializer = CatalogSerializer(data=data)
self.assertTrue(serializer.is_valid())
catalog = serializer.save()
self.assertEqual([viewer.username for viewer in catalog.viewers], [username])
self.assertEqual(User.objects.filter(username=username).count(), 1) # pylint: disable=no-member
def test_invalid_data_user_create(self):
"""Verify that users are not created if the serializer data is invalid."""
username = 'test-user'
......
......@@ -3,6 +3,7 @@ import datetime
import urllib
import ddt
from django.contrib.auth import get_user_model
import pytz
import responses
from rest_framework.reverse import reverse
......@@ -16,6 +17,8 @@ from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
User = get_user_model()
@ddt.ddt
class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixin, APITestCase):
......@@ -96,6 +99,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
self.mock_user_info_response(self.user)
self.assert_catalog_created(HTTP_AUTHORIZATION=self.generate_oauth2_token_header(self.user))
def test_create_with_new_user(self):
""" Verify that new users are created if the list of viewers includes the usernames of non-existent users. """
new_viewer_username = 'new-guy'
existing_viewer = UserFactory()
viewers = [new_viewer_username, existing_viewer.username]
data = {
'name': 'Test Catalog',
'query': '*:*',
'viewers': ','.join(viewers)
}
# NOTE: We explicitly avoid using the JSON data type so that we properly test string parsing.
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 201)
catalog = Catalog.objects.latest()
latest_user = User.objects.latest()
self.assertEqual(latest_user.username, new_viewer_username)
self.assertListEqual(list(catalog.viewers), [existing_viewer, latest_user])
def test_create_with_new_user_error(self):
""" Verify no users are created if an error occurs while processing a create request. """
# The missing name and query fields should trigger an error
data = {
'viewers': ['new-guy']
}
original_user_count = User.objects.count()
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(User.objects.count(), original_user_count)
def test_courses(self):
""" Verify the endpoint returns the list of courses contained in the catalog. """
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
......
......@@ -2,7 +2,9 @@ import logging
import os
from io import StringIO
from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.db import transaction
from django.db.models.functions import Lower
from django.shortcuts import get_object_or_404
from dry_rest_permissions.generics import DRYPermissions
......@@ -14,17 +16,18 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from course_discovery.apps.api.filters import PermissionsFilter
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer,
CourseSerializerExcludingClosedRuns, AffiliateWindowSerializer, ContainedCourseRunsSerializer
)
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat
logger = logging.getLogger(__name__)
User = get_user_model()
# pylint: disable=no-member
......@@ -37,10 +40,26 @@ class CatalogViewSet(viewsets.ModelViewSet):
queryset = Catalog.objects.all()
serializer_class = CatalogSerializer
# The boilerplate methods are required to be recognized by swagger
@transaction.atomic
def create(self, request, *args, **kwargs):
""" Create a new catalog. """
return super(CatalogViewSet, self).create(request, *args, **kwargs)
data = request.data.copy()
usernames = request.data.get('viewers', ())
# Add support for parsing a comma-separated list from Swagger
if isinstance(usernames, str):
usernames = usernames.split(',')
data.setlist('viewers', usernames)
# Ensure the users exist
for username in usernames:
User.objects.get_or_create(username=username)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def destroy(self, request, *args, **kwargs):
""" Destroy a catalog. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment