Commit 2d308c9b by Clinton Blackburn Committed by Peter Fogg

Updated catalog viewer creation logic

The logic for creating viewers has been moved to the view. This allows us to correct the data passed to the serializer when the viewers field is set to a comma-delimited list (e.g. for Swagger).

 ECOM-4489
parent 37df16bc
from urllib.parse import urlencode from urllib.parse import urlencode
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
...@@ -96,27 +95,10 @@ class CatalogSerializer(serializers.ModelSerializer): ...@@ -96,27 +95,10 @@ class CatalogSerializer(serializers.ModelSerializer):
allow_null=True, allow_empty=True, required=False, allow_null=True, allow_empty=True, required=False,
help_text=_('Usernames of users with explicit access to view this catalog')) help_text=_('Usernames of users with explicit access to view this catalog'))
def is_valid(self, **kwargs):
# Ensure that the catalog's viewers actually exist in the
# DB. We keep this in a transaction so that users are only
# created if the data is valid.
sid = transaction.savepoint()
for username in self.initial_data.get('viewers', ()): # pylint: disable=no-member
User.objects.get_or_create(username=username)
if super().is_valid(**kwargs):
# Data is good; commit the transaction.
transaction.savepoint_commit(sid)
return True
else:
# Invalid data; roll back the user creation.
transaction.savepoint_rollback(sid)
return False
def create(self, validated_data): def create(self, validated_data):
viewers = set() viewers = validated_data.pop('viewers')
for username in validated_data.pop('viewers'): viewers = User.objects.filter(username__in=viewers)
user = User.objects.get(username=username)
viewers.add(user)
# Set viewers after the model has been saved # Set viewers after the model has been saved
instance = super(CatalogSerializer, self).create(validated_data) instance = super(CatalogSerializer, self).create(validated_data)
instance.viewers = viewers instance.viewers = viewers
......
...@@ -39,21 +39,6 @@ class CatalogSerializerTests(TestCase): ...@@ -39,21 +39,6 @@ class CatalogSerializerTests(TestCase):
} }
self.assertDictEqual(serializer.data, expected) self.assertDictEqual(serializer.data, expected)
def test_create_new_user(self):
username = 'test-user'
data = {
'viewers': [username],
'id': None,
'name': 'test new catalog',
'query': '*',
}
self.assertEqual(User.objects.filter(username=username).count(), 0) # pylint: disable=no-member
serializer = CatalogSerializer(data=data)
self.assertTrue(serializer.is_valid())
catalog = serializer.save()
self.assertEqual([viewer.username for viewer in catalog.viewers], [username])
self.assertEqual(User.objects.filter(username=username).count(), 1) # pylint: disable=no-member
def test_invalid_data_user_create(self): def test_invalid_data_user_create(self):
"""Verify that users are not created if the serializer data is invalid.""" """Verify that users are not created if the serializer data is invalid."""
username = 'test-user' username = 'test-user'
......
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
import urllib import urllib
import ddt import ddt
from django.contrib.auth import get_user_model
import pytz import pytz
import responses import responses
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
...@@ -16,6 +17,8 @@ from course_discovery.apps.core.tests.factories import UserFactory ...@@ -16,6 +17,8 @@ from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory
User = get_user_model()
@ddt.ddt @ddt.ddt
class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixin, APITestCase): class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixin, APITestCase):
...@@ -96,6 +99,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi ...@@ -96,6 +99,38 @@ class CatalogViewSetTests(ElasticsearchTestMixin, SerializationMixin, OAuth2Mixi
self.mock_user_info_response(self.user) self.mock_user_info_response(self.user)
self.assert_catalog_created(HTTP_AUTHORIZATION=self.generate_oauth2_token_header(self.user)) self.assert_catalog_created(HTTP_AUTHORIZATION=self.generate_oauth2_token_header(self.user))
def test_create_with_new_user(self):
""" Verify that new users are created if the list of viewers includes the usernames of non-existent users. """
new_viewer_username = 'new-guy'
existing_viewer = UserFactory()
viewers = [new_viewer_username, existing_viewer.username]
data = {
'name': 'Test Catalog',
'query': '*:*',
'viewers': ','.join(viewers)
}
# NOTE: We explicitly avoid using the JSON data type so that we properly test string parsing.
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 201)
catalog = Catalog.objects.latest()
latest_user = User.objects.latest()
self.assertEqual(latest_user.username, new_viewer_username)
self.assertListEqual(list(catalog.viewers), [existing_viewer, latest_user])
def test_create_with_new_user_error(self):
""" Verify no users are created if an error occurs while processing a create request. """
# The missing name and query fields should trigger an error
data = {
'viewers': ['new-guy']
}
original_user_count = User.objects.count()
response = self.client.post(self.catalog_list_url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(User.objects.count(), original_user_count)
def test_courses(self): def test_courses(self):
""" Verify the endpoint returns the list of courses contained in the catalog. """ """ Verify the endpoint returns the list of courses contained in the catalog. """
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id}) url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
......
...@@ -2,7 +2,9 @@ import logging ...@@ -2,7 +2,9 @@ import logging
import os import os
from io import StringIO from io import StringIO
from django.contrib.auth import get_user_model
from django.core.management import call_command from django.core.management import call_command
from django.db import transaction
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from dry_rest_permissions.generics import DRYPermissions from dry_rest_permissions.generics import DRYPermissions
...@@ -14,17 +16,18 @@ from rest_framework.permissions import IsAuthenticated ...@@ -14,17 +16,18 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from course_discovery.apps.api.filters import PermissionsFilter from course_discovery.apps.api.filters import PermissionsFilter
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.api.serializers import ( from course_discovery.apps.api.serializers import (
CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer, CatalogSerializer, CourseSerializer, CourseRunSerializer, ContainedCoursesSerializer,
CourseSerializerExcludingClosedRuns, AffiliateWindowSerializer, ContainedCourseRunsSerializer CourseSerializerExcludingClosedRuns, AffiliateWindowSerializer, ContainedCourseRunsSerializer
) )
from course_discovery.apps.api.renderers import AffiliateWindowXMLRenderer
from course_discovery.apps.catalogs.models import Catalog from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.utils import SearchQuerySetWrapper from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.constants import COURSE_ID_REGEX, COURSE_RUN_ID_REGEX
from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat from course_discovery.apps.course_metadata.models import Course, CourseRun, Seat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
User = get_user_model()
# pylint: disable=no-member # pylint: disable=no-member
...@@ -37,10 +40,26 @@ class CatalogViewSet(viewsets.ModelViewSet): ...@@ -37,10 +40,26 @@ class CatalogViewSet(viewsets.ModelViewSet):
queryset = Catalog.objects.all() queryset = Catalog.objects.all()
serializer_class = CatalogSerializer serializer_class = CatalogSerializer
# The boilerplate methods are required to be recognized by swagger @transaction.atomic
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
""" Create a new catalog. """ """ Create a new catalog. """
return super(CatalogViewSet, self).create(request, *args, **kwargs) data = request.data.copy()
usernames = request.data.get('viewers', ())
# Add support for parsing a comma-separated list from Swagger
if isinstance(usernames, str):
usernames = usernames.split(',')
data.setlist('viewers', usernames)
# Ensure the users exist
for username in usernames:
User.objects.get_or_create(username=username)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
""" Destroy a catalog. """ """ Destroy a catalog. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment