Commit d7e83538 by Michael Terry Committed by Michael Terry

Support creating entitlements

Allow creating entitlement products via the AtomicPublication API.
Just pass them in like normal seat products, with an entitlement
product class instead of a seat class.

And remove the unused /products create endpoint (that had been
intended to support this workflow, but it's not atomic, so we'll
keep using the generic publication endpoint with new support for
entitlements and drop the unused create-one-entitlement endpoint).

LEARNER-3891
parent 2a281d0a
...@@ -22,8 +22,15 @@ def create_parent_course_entitlement(name, UUID): ...@@ -22,8 +22,15 @@ def create_parent_course_entitlement(name, UUID):
parent, created = Product.objects.get_or_create( parent, created = Product.objects.get_or_create(
structure=Product.PARENT, structure=Product.PARENT,
product_class=ProductClass.objects.get(name=COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME), product_class=ProductClass.objects.get(name=COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME),
title='Parent Course Entitlement for {}'.format(name), attributes__name='UUID',
attribute_values__value_text=UUID,
defaults={
'title': 'Parent Course Entitlement for {}'.format(name),
'is_discountable': True,
},
) )
parent.attr.UUID = UUID
parent.attr.save()
if created: if created:
logger.debug('Created new parent course_entitlement [%d] for [%s].', parent.id, UUID) logger.debug('Created new parent course_entitlement [%d] for [%s].', parent.id, UUID)
...@@ -31,10 +38,6 @@ def create_parent_course_entitlement(name, UUID): ...@@ -31,10 +38,6 @@ def create_parent_course_entitlement(name, UUID):
logger.debug('Parent course_entitlement [%d] already exists for [%s].', parent.id, UUID) logger.debug('Parent course_entitlement [%d] already exists for [%s].', parent.id, UUID)
ProductCategory.objects.get_or_create(category=Category.objects.get(name='Course Entitlements'), product=parent) ProductCategory.objects.get_or_create(category=Category.objects.get(name='Course Entitlements'), product=parent)
parent.title = 'Parent Course Entitlement for {}'.format(name)
parent.is_discountable = True
parent.attr.UUID = UUID
parent.save()
return parent, created return parent, created
...@@ -45,8 +48,11 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID, ...@@ -45,8 +48,11 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID,
certificate_type = certificate_type.lower() certificate_type = certificate_type.lower()
UUID = unicode(UUID) UUID = unicode(UUID)
uuid_query = Q(
attributes__name='UUID',
attribute_values__value_text=UUID,
)
certificate_type_query = Q( certificate_type_query = Q(
title='Course {}'.format(name),
attributes__name='certificate_type', attributes__name='certificate_type',
attribute_values__value_text=certificate_type, attribute_values__value_text=certificate_type,
) )
...@@ -54,7 +60,7 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID, ...@@ -54,7 +60,7 @@ def create_or_update_course_entitlement(certificate_type, price, partner, UUID,
try: try:
parent_entitlement, __ = create_parent_course_entitlement(name, UUID) parent_entitlement, __ = create_parent_course_entitlement(name, UUID)
all_products = parent_entitlement.children.all().prefetch_related('stockrecords') all_products = parent_entitlement.children.all().prefetch_related('stockrecords')
course_entitlement = all_products.get(certificate_type_query) course_entitlement = all_products.filter(uuid_query).get(certificate_type_query)
except Product.DoesNotExist: except Product.DoesNotExist:
course_entitlement = Product() course_entitlement = Product()
......
...@@ -14,9 +14,12 @@ from oscar.core.loading import get_class, get_model ...@@ -14,9 +14,12 @@ from oscar.core.loading import get_class, get_model
from rest_framework import serializers from rest_framework import serializers
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from ecommerce.core.constants import COURSE_ID_REGEX, ENROLLMENT_CODE_SWITCH, ISO_8601_FORMAT, SEAT_PRODUCT_CLASS_NAME from ecommerce.core.constants import (COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME, COURSE_ID_REGEX, ENROLLMENT_CODE_SWITCH,
ISO_8601_FORMAT, SEAT_PRODUCT_CLASS_NAME)
from ecommerce.core.url_utils import get_ecommerce_url from ecommerce.core.url_utils import get_ecommerce_url
from ecommerce.courses.models import Course from ecommerce.courses.models import Course
from ecommerce.courses.utils import get_course_info_from_catalog
from ecommerce.entitlements.utils import create_or_update_course_entitlement
from ecommerce.invoice.models import Invoice from ecommerce.invoice.models import Invoice
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -97,6 +100,11 @@ def retrieve_voucher_usage(obj): ...@@ -97,6 +100,11 @@ def retrieve_voucher_usage(obj):
return retrieve_voucher(obj).usage return retrieve_voucher(obj).usage
def _flatten(attrs):
"""Transform a list of attribute names and values into a dictionary keyed on the names."""
return {attr['name']: attr['value'] for attr in attrs}
class ProductPaymentInfoMixin(serializers.ModelSerializer): class ProductPaymentInfoMixin(serializers.ModelSerializer):
""" Mixin class used for retrieving price information from products. """ """ Mixin class used for retrieving price information from products. """
price = serializers.SerializerMethodField() price = serializers.SerializerMethodField()
...@@ -317,6 +325,73 @@ class CourseSerializer(serializers.HyperlinkedModelSerializer): ...@@ -317,6 +325,73 @@ class CourseSerializer(serializers.HyperlinkedModelSerializer):
} }
class EntitlementProductHelper(object):
@staticmethod
def validate(product):
attrs = _flatten(product['attribute_values'])
if 'certificate_type' not in attrs:
raise serializers.ValidationError(_(u"Products must have a certificate type."))
if 'price' not in product:
raise serializers.ValidationError(_(u"Products must have a price."))
@staticmethod
def save(partner, course, uuid, product):
attrs = _flatten(product['attribute_values'])
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type')
price = Decimal(product['price'])
create_or_update_course_entitlement(
certificate_type,
price,
partner,
uuid,
course.name
)
class SeatProductHelper(object):
@staticmethod
def validate(product):
attrs = _flatten(product['attribute_values'])
if attrs.get('id_verification_required') is None:
raise serializers.ValidationError(_(u"Products must indicate whether ID verification is required."))
# Verify that a price is present.
if product.get('price') is None:
raise serializers.ValidationError(_(u"Products must have a price."))
@staticmethod
def save(partner, course, product, create_enrollment_code):
attrs = _flatten(product['attribute_values'])
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type', '')
id_verification_required = attrs['id_verification_required']
price = Decimal(product['price'])
# Extract arguments which are optional for Seat creation, deserializing as necessary.
expires = product.get('expires')
expires = parse(expires) if expires else None
credit_provider = attrs.get('credit_provider')
credit_hours = attrs.get('credit_hours')
credit_hours = int(credit_hours) if credit_hours else None
course.create_or_update_seat(
certificate_type,
id_verification_required,
price,
partner,
expires=expires,
credit_provider=credit_provider,
credit_hours=credit_hours,
create_enrollment_code=create_enrollment_code
)
class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=abstract-method class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""Serializer for saving and publishing a Course and associated products. """Serializer for saving and publishing a Course and associated products.
...@@ -337,22 +412,17 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab ...@@ -337,22 +412,17 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
def validate_products(self, products): def validate_products(self, products):
"""Validate product data.""" """Validate product data."""
for product in products: for product in products:
# Verify that each product is intended to be a Seat.
product_class = product.get('product_class') product_class = product.get('product_class')
if product_class != SEAT_PRODUCT_CLASS_NAME:
if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
EntitlementProductHelper.validate(product)
elif product_class == SEAT_PRODUCT_CLASS_NAME:
SeatProductHelper.validate(product)
else:
raise serializers.ValidationError( raise serializers.ValidationError(
_(u"Invalid product class [{product_class}] requested.".format(product_class=product_class)) _(u"Invalid product class [{product_class}] requested.").format(product_class=product_class)
) )
# Verify that attributes required to create a Seat are present.
attrs = self._flatten(product['attribute_values'])
if attrs.get('id_verification_required') is None:
raise serializers.ValidationError(_(u"Products must indicate whether ID verification is required."))
# Verify that a price is present.
if product.get('price') is None:
raise serializers.ValidationError(_(u"Products must have a price."))
return products return products
def get_partner(self): def get_partner(self):
...@@ -395,34 +465,20 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab ...@@ -395,34 +465,20 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
course.verification_deadline = course_verification_deadline course.verification_deadline = course_verification_deadline
course.save() course.save()
# Fetch full course info (from seat product, because this queries based on attr.course_key)
course_info = get_course_info_from_catalog(course.site, course.parent_seat_product)
create_enrollment_code = False create_enrollment_code = False
if waffle.switch_is_active(ENROLLMENT_CODE_SWITCH) and site.siteconfiguration.enable_enrollment_codes: if waffle.switch_is_active(ENROLLMENT_CODE_SWITCH) and site.siteconfiguration.enable_enrollment_codes:
create_enrollment_code = create_or_activate_enrollment_code create_enrollment_code = create_or_activate_enrollment_code
for product in products: for product in products:
attrs = self._flatten(product['attribute_values']) product_class = product.get('product_class')
# Extract arguments required for Seat creation, deserializing as necessary.
certificate_type = attrs.get('certificate_type', '') if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
id_verification_required = attrs['id_verification_required'] EntitlementProductHelper.save(partner, course, course_info['UUID'], product)
price = Decimal(product['price']) elif product_class == SEAT_PRODUCT_CLASS_NAME:
SeatProductHelper.save(partner, course, product, create_enrollment_code)
# Extract arguments which are optional for Seat creation, deserializing as necessary.
expires = product.get('expires')
expires = parse(expires) if expires else None
credit_provider = attrs.get('credit_provider')
credit_hours = attrs.get('credit_hours')
credit_hours = int(credit_hours) if credit_hours else None
course.create_or_update_seat(
certificate_type,
id_verification_required,
price,
partner,
expires=expires,
credit_provider=credit_provider,
credit_hours=credit_hours,
create_enrollment_code=create_enrollment_code
)
if course.get_enrollment_code(): if course.get_enrollment_code():
course.toggle_enrollment_code_status(is_active=create_enrollment_code) course.toggle_enrollment_code_status(is_active=create_enrollment_code)
...@@ -439,10 +495,6 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab ...@@ -439,10 +495,6 @@ class AtomicPublicationSerializer(serializers.Serializer): # pylint: disable=ab
logger.exception(u'Failed to save and publish [%s]: [%s]', course_id, e.message) logger.exception(u'Failed to save and publish [%s]: [%s]', course_id, e.message)
return False, e, e.message return False, e, e.message
def _flatten(self, attrs):
"""Transform a list of attribute names and values into a dictionary keyed on the names."""
return {attr['name']: attr['value'] for attr in attrs}
class PartnerSerializer(serializers.ModelSerializer): class PartnerSerializer(serializers.ModelSerializer):
"""Serializer for the Partner object""" """Serializer for the Partner object"""
......
...@@ -8,7 +8,7 @@ from django.test import RequestFactory ...@@ -8,7 +8,7 @@ from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from oscar.core.loading import get_model from oscar.core.loading import get_model
from ecommerce.core.constants import COUPON_PRODUCT_CLASS_NAME, COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME from ecommerce.core.constants import COUPON_PRODUCT_CLASS_NAME
from ecommerce.coupons.tests.mixins import CouponMixin from ecommerce.coupons.tests.mixins import CouponMixin
from ecommerce.courses.tests.factories import CourseFactory from ecommerce.courses.tests.factories import CourseFactory
from ecommerce.extensions.api.serializers import ProductSerializer from ecommerce.extensions.api.serializers import ProductSerializer
...@@ -135,49 +135,6 @@ class ProductViewSetTests(ProductViewSetBase): ...@@ -135,49 +135,6 @@ class ProductViewSetTests(ProductViewSetBase):
self.assertDictEqual(json.loads(response.content), expected) self.assertDictEqual(json.loads(response.content), expected)
class ProductViewSetCourseEntitlementTests(ProductViewSetBase):
def setUp(self):
self.entitlement_data = {
"product_class": COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME,
"title": "Test Course",
"price": 50,
"expires": "2018-10-10T00:00:00Z",
"attribute_values": [
{
"name": "certificate_type",
"code": "certificate_type",
"value": "verified"
},
{
"name": "UUID",
"code": "UUID",
"value": "f9044e15-133f-4a4f-b587-99530e8a8e88"
}
],
"is_available_to_buy": "false"
}
super(ProductViewSetCourseEntitlementTests, self).setUp()
def test_entitlement_post(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
response = self.client.post('/api/v2/products/', json.dumps(self.entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 201)
def test_entitlement_post_bad_request(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
bad_entitlement_data = self.entitlement_data
bad_entitlement_data['attribute_values'] = []
response = self.client.post('/api/v2/products/', json.dumps(bad_entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 400)
def test_non_entitlement_post(self):
""" Verify the view allows individual Course Entitlement products to be made via post"""
bad_entitlement_data = self.entitlement_data
bad_entitlement_data['product_class'] = 'Seat'
response = self.client.post('/api/v2/products/', json.dumps(bad_entitlement_data), JSON_CONTENT_TYPE)
self.assertEqual(response.status_code, 400)
class ProductViewSetCouponTests(CouponMixin, ProductViewSetBase): class ProductViewSetCouponTests(CouponMixin, ProductViewSetBase):
def test_coupon_product_details(self): def test_coupon_product_details(self):
"""Verify the endpoint returns all coupon information.""" """Verify the endpoint returns all coupon information."""
......
"""HTTP endpoints for interacting with products.""" """HTTP endpoints for interacting with products."""
from django.db.models import Q from django.db.models import Q
from django.http import HttpResponseBadRequest
from oscar.core.loading import get_model from oscar.core.loading import get_model
from rest_framework import filters, status from rest_framework import filters
from rest_framework.permissions import IsAdminUser, IsAuthenticated from rest_framework.permissions import IsAdminUser, IsAuthenticated
from rest_framework.response import Response
from rest_framework_extensions.mixins import NestedViewSetMixin from rest_framework_extensions.mixins import NestedViewSetMixin
from ecommerce.core.constants import COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME
from ecommerce.entitlements.utils import create_or_update_course_entitlement
from ecommerce.extensions.api import serializers from ecommerce.extensions.api import serializers
from ecommerce.extensions.api.filters import ProductFilter from ecommerce.extensions.api.filters import ProductFilter
from ecommerce.extensions.api.v2.views import NonDestroyableModelViewSet from ecommerce.extensions.api.v2.views import NonDestroyableModelViewSet
...@@ -33,39 +29,3 @@ class ProductViewSet(NestedViewSetMixin, NonDestroyableModelViewSet): ...@@ -33,39 +29,3 @@ class ProductViewSet(NestedViewSetMixin, NonDestroyableModelViewSet):
Q(stockrecords__partner=self.request.site.siteconfiguration.partner) | Q(stockrecords__partner=self.request.site.siteconfiguration.partner) |
Q(course__site=self.request.site) Q(course__site=self.request.site)
) )
def create(self, request, *args, **kwargs):
product_class = request.data.get('product_class')
if product_class == COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME:
product_creation_fields = {
'partner': request.site.siteconfiguration.partner,
'name': request.data.get('title'),
'price': request.data.get('price'),
'certificate_type': self._fetch_value_from_attribute_values('certificate_type'),
'UUID': self._fetch_value_from_attribute_values('UUID')
}
for attribute_name, attribute_value in product_creation_fields.items():
if attribute_value is None:
bad_rqst = 'Missing or bad value for: {}, required for Entitlement creation.'.format(attribute_name)
return HttpResponseBadRequest(bad_rqst)
entitlement = create_or_update_course_entitlement(
product_creation_fields['certificate_type'],
product_creation_fields['price'],
product_creation_fields['partner'],
product_creation_fields['UUID'],
product_creation_fields['name']
)
data = self.serializer_class(entitlement, context={'request': request}).data
return Response(data, status=status.HTTP_201_CREATED)
else:
bad_rqst = "Product API only supports POST for {} products".format(COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME)
return HttpResponseBadRequest(bad_rqst)
def _fetch_value_from_attribute_values(self, attribute_name):
attributes = {attribute.get('name'): attribute.get('value') for attribute in self.request.data.get('attribute_values')} # pylint: disable=line-too-long
val = attributes.get(attribute_name)
return val
...@@ -7,6 +7,7 @@ from oscar.core.utils import slugify ...@@ -7,6 +7,7 @@ from oscar.core.utils import slugify
from oscar.test import factories from oscar.test import factories
from ecommerce.core.constants import ( from ecommerce.core.constants import (
COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME,
ENROLLMENT_CODE_PRODUCT_CLASS_NAME, ENROLLMENT_CODE_PRODUCT_CLASS_NAME,
ENROLLMENT_CODE_SWITCH, ENROLLMENT_CODE_SWITCH,
SEAT_PRODUCT_CLASS_NAME SEAT_PRODUCT_CLASS_NAME
...@@ -35,14 +36,16 @@ class DiscoveryTestMixin(object): ...@@ -35,14 +36,16 @@ class DiscoveryTestMixin(object):
super(DiscoveryTestMixin, self).setUp() super(DiscoveryTestMixin, self).setUp()
# Force the creation of a seat ProductClass # Force the creation of a seat ProductClass
self.entitlement_product_class # pylint: disable=pointless-statement
self.seat_product_class # pylint: disable=pointless-statement self.seat_product_class # pylint: disable=pointless-statement
self.enrollment_code_product_class # pylint: disable=pointless-statement self.enrollment_code_product_class # pylint: disable=pointless-statement
category_name = 'Seats' for category_name in ['Course Entitlements', 'Seats']:
try: try:
self.category = Category.objects.get(name=category_name) Category.objects.get(name=category_name)
except Category.DoesNotExist: except Category.DoesNotExist:
self.category = factories.CategoryFactory(name=category_name) factories.CategoryFactory(name=category_name)
self.category = Category.objects.get(name='Seats')
def create_course_and_seat( def create_course_and_seat(
self, course_id=None, seat_type='verified', id_verification=False, price=10, partner=None self, course_id=None, seat_type='verified', id_verification=False, price=10, partner=None
...@@ -93,6 +96,17 @@ class DiscoveryTestMixin(object): ...@@ -93,6 +96,17 @@ class DiscoveryTestMixin(object):
return pc return pc
@property @property
def entitlement_product_class(self):
attributes = (
('certificate_type', 'text'),
('UUID', 'text'),
)
product_class = self._create_product_class(
COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME, slugify(COURSE_ENTITLEMENT_PRODUCT_CLASS_NAME), attributes
)
return product_class
@property
def seat_product_class(self): def seat_product_class(self):
attributes = ( attributes = (
('certificate_type', 'text'), ('certificate_type', 'text'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment