Commit d7e83538 by Michael Terry Committed by Michael Terry

Support creating entitlements

Allow creating entitlement products via the AtomicPublication API.
Just pass them in like normal seat products, with an entitlement
product class instead of a seat class.

And remove the unused /products create endpoint (that had been
intended to support this workflow, but it's not atomic, so we'll
keep using the generic publication endpoint with new support for
entitlements and drop the unused create-one-entitlement endpoint).

LEARNER-3891
parent 2a281d0a
......@@ -22,8 +22,15 @@ def create_parent_course_entitlement(name, UUID):
parent, created = Product.objects.get_or_create(
structure=Product.PARENT,
product_class=ProductClass.objects.get(name=COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME),
title='Parent Course Entitlement for {}'.format(name),
attributes__name='UUID',
attribute_values__value_text=UUID,
defaults={
'title': 'Parent Course Entitlement for {}'.format(name),
'is_discountable': True,
},
)
parent.attr.UUID = UUID
parent.attr.save()
if created:
logger.debug('Created new parent course_entitlement [%d] for [%s].', parent.id, UUID)
......@@ -31,10 +38,6 @@ def create_parent_course_entitlement(name, UUID):
logger.debug('Parent course_entitlement [%d] already exists for [%s].', parent.id, UUID)
ProductCategory.objects.get_or_create(category=Category.objects.get(name='Course Entitlements'), product=parent)
parent.title = 'Parent Course Entitlement for {}'.format(name)
parent.is_discountable = True
parent.attr.UUID = UUID
parent.save()
return parent, created
......@@ -45,8 +48,11 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID,
certificate_type = certificate_type.lower()
UUID = unicode(UUID)
uuid_query = Q(
attributes__name='UUID',
attribute_values__value_text=UUID,
)
certificate_type_query = Q(
title='Course {}'.format(name),
attributes__name='certificate_type',
attribute_values__value_text=certificate_type,
)
......@@ -54,7 +60,7 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID,
try:
parent_entitlement, __ = create_parent_course_entitlement(name, UUID)
all_products = parent_entitlement.children.all().prefetch_related('stockrecords')
course_entitlement = all_products.get(certificate_type_query)
course_entitlement = all_products.filter(uuid_query).get(certificate_type_query)
except Product.DoesNotExist:
course_entitlement = Product()
......
......@@ -14,9 +14,12 @@ from oscar.core.loading import get_class, get_model
from rest_framework import serializers
from rest_framework.reverse import reverse
from ecommerce.core.constants import COURSE_ID_REGEX, ENROLLMENT_CODE_SWITCH, ISO_8601_FORMAT, SEAT_PRODUCT_CLASS_NAME
from ecommerce.core.constants import (COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME, COURSE_ID_REGEX, ENROLLMENT_CODE_SWITCH,
ISO_8601_FORMAT, SEAT_PRODUCT_CLASS_NAME)
from ecommerce.core.url_utils import get_ecommerce_url
from ecommerce.courses.models import Course
from ecommerce.courses.utils import get_course_info_from_catalog
from ecommerce.entitlements.utils import create_or_update_course_entitlement
from ecommerce.invoice.models import Invoice
logger = logging.getLogger(__name__)
......@@ -97,6 +100,11 @@ def retrieve_voucher_usage(obj):
return retrieve_voucher(obj).usage
def _flatten(attrs):
"""Transform a list of attribute names and values into a dictionary keyed on the names."""
return {attr['name']: attr['value'] for attr in attrs}
class ProductPaymentInfoMixin(serializers.ModelSerializer):
""" Mixin class used for retrieving price information from products. """
price = serializers.SerializerMethodField()
......@@ -317,6 +325,73 @@ class CourseSerializer(serializers.HyperlinkedModelSerializer):
}
class EntitlementProductHelper(object):
@staticmethod
def validate(product):
attrs = _flatten(product['attribute_values'])
if 'certificate_type' not in attrs:
raise serializers.ValidationError(_(u"Products must have a certificate type."))
if 'price' not in product:
raise serializers.ValidationError(_(u"Products must have a price."))
@staticmethod
def save(partner, course, uuid, product):
attrs = _flatten(product['attribute_values'])
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type')
price = Decimal(product['price'])
create_or_update_course_entitlement(
certificate_type,
price,
partner,
uuid,
course.name
)
class SeatProductHelper(object):
@staticmethod
def validate(product):
attrs = _flatten(product['attribute_values'])
if attrs.get('id_verification_required') is None:
raise serializers.ValidationError(_(u"Products must indicate whether ID verification is required."))
# Verify that a price is present.
if product.get('price') is None:
raise serializers.ValidationError(_(u"Products must have a price."))
@staticmethod
def save(partner, course, product, create_enrollment_code):
attrs = _flatten(product['attribute_values'])
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type', '')
id_verification_required = attrs['id_verification_required']
price = Decimal(product['price'])
# Extract arguments which are optional for Seat creation, deserializing as necessary.
expires = product.get('expires')
expires = parse(expires) if expires else None
credit_provider = attrs.get('credit_provider')
credit_hours = attrs.get('credit_hours')
credit_hours = int(credit_hours) if credit_hours else None
course.create_or_update_seat(
certificate_type,
id_verification_required,
price,
partner,
expires=expires,
credit_provider=credit_provider,
credit_hours=credit_hours,
create_enrollment_code=create_enrollment_code
)
class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""Serializer for saving and publishing a Course and associated products.
......@@ -337,22 +412,17 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
def validate_products(self, products):
"""Validate product data."""
for product in products:
# Verify that each product is intended to be a Seat.
product_class = product.get('product_class')
if product_class != SEAT_PRODUCT_CLASS_NAME:
if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
EntitlementProductHelper.validate(product)
elif product_class == SEAT_PRODUCT_CLASS_NAME:
SeatProductHelper.validate(product)
else:
raise serializers.ValidationError(
_(u"Invalid product class [{product_class}] requested.".format(product_class=product_class))
_(u"Invalid product class [{product_class}] requested.").format(product_class=product_class)
)
# Verify that attributes required to create a Seat are present.
attrs = self._flatten(product['attribute_values'])
if attrs.get('id_verification_required') is None:
raise serializers.ValidationError(_(u"Products must indicate whether ID verification is required."))
# Verify that a price is present.
if product.get('price') is None:
raise serializers.ValidationError(_(u"Products must have a price."))
return products
def get_partner(self):
......@@ -395,34 +465,20 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
course.verification_deadline = course_verification_deadline
course.save()
# Fetch full course info (from seat product, because this queries based on attr.course_key)
course_info = get_course_info_from_catalog(course.site, course.parent_seat_product)
create_enrollment_code = False
if waffle.switch_is_active(ENROLLMENT_CODE_SWITCH) and site.siteconfiguration.enable_enrollment_codes:
create_enrollment_code = create_or_activate_enrollment_code
for product in products:
attrs = self._flatten(product['attribute_values'])
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type', '')
id_verification_required = attrs['id_verification_required']
price = Decimal(product['price'])
# Extract arguments which are optional for Seat creation, deserializing as necessary.
expires = product.get('expires')
expires = parse(expires) if expires else None
credit_provider = attrs.get('credit_provider')
credit_hours = attrs.get('credit_hours')
credit_hours = int(credit_hours) if credit_hours else None
course.create_or_update_seat(
certificate_type,
id_verification_required,
price,
partner,
expires=expires,
credit_provider=credit_provider,
credit_hours=credit_hours,
create_enrollment_code=create_enrollment_code
)
product_class = product.get('product_class')
if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
EntitlementProductHelper.save(partner, course, course_info['UUID'], product)
elif product_class == SEAT_PRODUCT_CLASS_NAME:
SeatProductHelper.save(partner, course, product, create_enrollment_code)
if course.get_enrollment_code():
course.toggle_enrollment_code_status(is_active=create_enrollment_code)
......@@ -439,10 +495,6 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
logger.exception(u'Failed to save and publish [%s]: [%s]', course_id, e.message)
return False, e, e.message
def _flatten(self, attrs):
"""Transform a list of attribute names and values into a dictionary keyed on the names."""
return {attr['name']: attr['value'] for attr in attrs}
class PartnerSerializer(serializers.ModelSerializer):
"""Serializer for the Partner object"""
......
......@@ -8,7 +8,7 @@ from django.test import RequestFactory
from django.urls import reverse
from oscar.core.loading import get_model
from ecommerce.core.constants import COUPON_PRODUCT_CLASS_NAME, COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME
from ecommerce.core.constants import COUPON_PRODUCT_CLASS_NAME
from ecommerce.coupons.tests.mixins import CouponMixin
from ecommerce.courses.tests.factories import CourseFactory
from ecommerce.extensions.api.serializers import ProductSerializer
......@@ -135,49 +135,6 @@ class ProductViewSetTests(ProductViewSetBase):
self.assertDictEqual(json.loads(response.content), expected)
class ProductViewSetCourseEntitlementTests(ProductViewSetBase):
def setUp(self):
self.entitlement_data = {
"product_class": COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME,
"title": "Test Course",
"price": 50,
"expires": "2018-10-10T00:00:00Z",
"attribute_values": [
{
"name": "certificate_type",
"code": "certificate_type",
"value": "verified"
},
{
"name": "UUID",
"code": "UUID",
"value": "f9044e15-133f-4a4f-b587-99530e8a8e88"
}
],
"is_available_to_buy": "false"
}
super(ProductViewSetCourseEntitlementTests, self).setUp()
def test_entitlement_post(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
response = self.client.post('/api/v2/products/', json.dumps(self.entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 201)
def test_entitlement_post_bad_request(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
bad_entitlement_data = self.entitlement_data
bad_entitlement_data['attribute_values'] = []
response = self.client.post('/api/v2/products/', json.dumps(bad_entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 400)
def test_non_entitlement_post(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
bad_entitlement_data = self.entitlement_data
bad_entitlement_data['product_class'] = 'Seat'
response = self.client.post('/api/v2/products/', json.dumps(bad_entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 400)
class ProductViewSetCouponTests(CouponMixin, ProductViewSetBase):
def test_coupon_product_details(self):
"""Verify the endpoint returns all coupon information."""
......
"""HTTP endpoints for interacting with products."""
from django.db.models import Q
from django.http import HttpResponseBadRequest
from oscar.core.loading import get_model
from rest_framework import filters, status
from rest_framework import filters
from rest_framework.permissions import IsAdminUser, IsAuthenticated
from rest_framework.response import Response
from rest_framework_extensions.mixins import NestedViewSetMixin
from ecommerce.core.constants import COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME
from ecommerce.entitlements.utils import create_or_update_course_entitlement
from ecommerce.extensions.api import serializers
from ecommerce.extensions.api.filters import ProductFilter
from ecommerce.extensions.api.v2.views import NonDestroyableModelViewSet
......@@ -33,39 +29,3 @@ class ProductViewSet(NestedViewSetMixin, NonDestroyableModelViewSet):
Q(stockrecords__partner=self.request.site.siteconfiguration.partner) |
Q(course__site=self.request.site)
)
def create(self, request, *args, **kwargs):
product_class = request.data.get('product_class')
if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
product_creation_fields = {
'partner': request.site.siteconfiguration.partner,
'name': request.data.get('title'),
'price': request.data.get('price'),
'certificate_type': self._fetch_value_from_attribute_values('certificate_type'),
'UUID': self._fetch_value_from_attribute_values('UUID')
}
for attribute_name, attribute_value in product_creation_fields.items():
if attribute_value is None:
bad_rqst = 'Missing or bad value for: {}, required for Entitlement creation.'.format(attribute_name)
return HttpResponseBadRequest(bad_rqst)
entitlement = create_or_update_course_entitlement(
product_creation_fields['certificate_type'],
product_creation_fields['price'],
product_creation_fields['partner'],
product_creation_fields['UUID'],
product_creation_fields['name']
)
data = self.serializer_class(entitlement, context={'request': request}).data
return Response(data, status=status.HTTP_201_CREATED)
else:
bad_rqst = "Product API only supports POST for {} products".format(COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME)
return HttpResponseBadRequest(bad_rqst)
def _fetch_value_from_attribute_values(self, attribute_name):
attributes = {attribute.get('name'): attribute.get('value') for attribute in self.request.data.get('attribute_values')} # pylint: disable=line-too-long
val = attributes.get(attribute_name)
return val
......@@ -7,6 +7,7 @@ from oscar.core.utils import slugify
from oscar.test import factories
from ecommerce.core.constants import (
COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME,
ENROLLMENT_CODE_PRODUCT_CLASS_NAME,
ENROLLMENT_CODE_SWITCH,
SEAT_PRODUCT_CLASS_NAME
......@@ -35,14 +36,16 @@ class DiscoveryTestMixin(object):
super(DiscoveryTestMixin, self).setUp()
# Force the creation of a seat ProductClass
self.entitlement_product_class # pylint: disable=pointless-statement
self.seat_product_class # pylint: disable=pointless-statement
self.enrollment_code_product_class # pylint: disable=pointless-statement
category_name = 'Seats'
try:
self.category = Category.objects.get(name=category_name)
except Category.DoesNotExist:
self.category = factories.CategoryFactory(name=category_name)
for category_name in ['Course Entitlements', 'Seats']:
try:
Category.objects.get(name=category_name)
except Category.DoesNotExist:
factories.CategoryFactory(name=category_name)
self.category = Category.objects.get(name='Seats')
def create_course_and_seat(
self, course_id=None, seat_type='verified', id_verification=False, price=10, partner=None
......@@ -93,6 +96,17 @@ class DiscoveryTestMixin(object):
return pc
@property
def entitlement_product_class(self):
attributes = (
('certificate_type', 'text'),
('UUID', 'text'),
)
product_class = self._create_product_class(
COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME, slugify(COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME), attributes
)
return product_class
@property
def seat_product_class(self):
attributes = (
('certificate_type', 'text'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment