Commit 3e230d70 by Vedran Karačić Committed by GitHub

Merge pull request #888 from edx/vkaracic/fixes

Fixes
parents 5bb3ab35 e8e5f5a5
...@@ -447,6 +447,7 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer): ...@@ -447,6 +447,7 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer):
quantity = serializers.SerializerMethodField() quantity = serializers.SerializerMethodField()
start_date = serializers.SerializerMethodField() start_date = serializers.SerializerMethodField()
voucher_type = serializers.SerializerMethodField() voucher_type = serializers.SerializerMethodField()
seats = serializers.SerializerMethodField()
def retrieve_benefit(self, obj): def retrieve_benefit(self, obj):
"""Helper method to retrieve the benefit from voucher. """ """Helper method to retrieve the benefit from voucher. """
...@@ -551,6 +552,18 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer): ...@@ -551,6 +552,18 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer):
def get_voucher_type(self, obj): def get_voucher_type(self, obj):
return self.retrieve_voucher_usage(obj) return self.retrieve_voucher_usage(obj)
def get_seats(self, obj):
offer = self.retrieve_offer(obj)
_range = offer.condition.range
request = self.context['request']
if _range.catalog:
stockrecords = _range.catalog.stock_records.all()
seats = Product.objects.filter(id__in=[sr.product.id for sr in stockrecords])
serializer = ProductSerializer(seats, many=True, context={'request': request})
return serializer.data
else:
return None
class Meta(object): class Meta(object):
model = Product model = Product
fields = ( fields = (
...@@ -558,7 +571,7 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer): ...@@ -558,7 +571,7 @@ class CouponSerializer(ProductPaymentInfoMixin, serializers.ModelSerializer):
'categories', 'client', 'code', 'code_status', 'coupon_type', 'categories', 'client', 'code', 'code_status', 'coupon_type',
'course_seat_types', 'end_date', 'id', 'last_edited', 'max_uses', 'course_seat_types', 'end_date', 'id', 'last_edited', 'max_uses',
'note', 'num_uses', 'payment_information', 'price', 'quantity', 'note', 'num_uses', 'payment_information', 'price', 'quantity',
'start_date', 'title', 'voucher_type' 'start_date', 'title', 'voucher_type', 'seats'
) )
......
...@@ -45,6 +45,11 @@ class Basket(AbstractBasket): ...@@ -45,6 +45,11 @@ class Basket(AbstractBasket):
return basket return basket
def clear_vouchers(self):
"""Remove all vouchers applied to the basket."""
for v in self.vouchers.all():
self.vouchers.remove(v)
def __unicode__(self): def __unicode__(self):
return _(u"{id} - {status} basket (owner: {owner}, lines: {num_lines})").format( return _(u"{id} - {status} basket (owner: {owner}, lines: {num_lines})").format(
id=self.id, id=self.id,
......
...@@ -7,6 +7,7 @@ from ecommerce.core.constants import ENROLLMENT_CODE_PRODUCT_CLASS_NAME, ENROLLM ...@@ -7,6 +7,7 @@ from ecommerce.core.constants import ENROLLMENT_CODE_PRODUCT_CLASS_NAME, ENROLLM
from ecommerce.core.tests import toggle_switch from ecommerce.core.tests import toggle_switch
from ecommerce.courses.tests.factories import CourseFactory from ecommerce.courses.tests.factories import CourseFactory
from ecommerce.extensions.basket.utils import prepare_basket from ecommerce.extensions.basket.utils import prepare_basket
from ecommerce.extensions.catalogue.tests.mixins import CourseCatalogTestMixin
from ecommerce.extensions.partner.models import StockRecord from ecommerce.extensions.partner.models import StockRecord
from ecommerce.extensions.test.factories import prepare_voucher from ecommerce.extensions.test.factories import prepare_voucher
from ecommerce.referrals.models import Referral from ecommerce.referrals.models import Referral
...@@ -19,7 +20,7 @@ Product = get_model('catalogue', 'Product') ...@@ -19,7 +20,7 @@ Product = get_model('catalogue', 'Product')
@ddt.ddt @ddt.ddt
class BasketUtilsTests(TestCase): class BasketUtilsTests(CourseCatalogTestMixin, TestCase):
""" Tests for basket utility functions. """ """ Tests for basket utility functions. """
def setUp(self): def setUp(self):
...@@ -51,20 +52,22 @@ class BasketUtilsTests(TestCase): ...@@ -51,20 +52,22 @@ class BasketUtilsTests(TestCase):
self.assertEqual(basket.total_excl_tax, 90.00) self.assertEqual(basket.total_excl_tax, 90.00)
def test_prepare_basket_enrollment_with_voucher(self): def test_prepare_basket_enrollment_with_voucher(self):
"""Verify the basket does not contain a voucher if enrollment code is added to it."""
course = CourseFactory() course = CourseFactory()
toggle_switch(ENROLLMENT_CODE_SWITCH, True) toggle_switch(ENROLLMENT_CODE_SWITCH, True)
course.create_or_update_seat('verified', False, 10, self.partner, create_enrollment_code=True) course.create_or_update_seat('verified', False, 10, self.partner, create_enrollment_code=True)
enrollment_code = Product.objects.get(product_class__name=ENROLLMENT_CODE_PRODUCT_CLASS_NAME) enrollment_code = Product.objects.get(product_class__name=ENROLLMENT_CODE_PRODUCT_CLASS_NAME)
# Prepare a product with price of 100 and a voucher with 10% discount for that product. voucher, product = prepare_voucher()
product = ProductFactory(stockrecords__price_excl_tax=100)
new_range = RangeFactory(products=[product, ])
voucher, product = prepare_voucher(_range=new_range, benefit_value=10)
basket = prepare_basket(self.request, product, voucher) basket = prepare_basket(self.request, product, voucher)
self.assertIsNotNone(basket) self.assertIsNotNone(basket)
self.assertEqual(basket.vouchers.count(), 1) self.assertEqual(basket.all_lines()[0].product, product)
basket = prepare_basket(self.request, enrollment_code) self.assertTrue(basket.contains_a_voucher)
basket = prepare_basket(self.request, enrollment_code, voucher)
self.assertIsNotNone(basket) self.assertIsNotNone(basket)
self.assertEqual(basket.vouchers.count(), 0) self.assertEqual(basket.all_lines()[0].product, enrollment_code)
self.assertFalse(basket.contains_a_voucher)
def test_multiple_vouchers(self): def test_multiple_vouchers(self):
""" Verify only the last entered voucher is contained in the basket. """ """ Verify only the last entered voucher is contained in the basket. """
......
...@@ -34,13 +34,13 @@ def prepare_basket(request, product, voucher=None): ...@@ -34,13 +34,13 @@ def prepare_basket(request, product, voucher=None):
basket = Basket.get_basket(request.user, request.site) basket = Basket.get_basket(request.user, request.site)
basket.flush() basket.flush()
basket.add_product(product, 1) basket.add_product(product, 1)
if voucher or product.get_product_class().name == ENROLLMENT_CODE_PRODUCT_CLASS_NAME: if product.get_product_class().name == ENROLLMENT_CODE_PRODUCT_CLASS_NAME:
for v in basket.vouchers.all(): basket.clear_vouchers()
basket.vouchers.remove(v) elif voucher:
if voucher: basket.clear_vouchers()
basket.vouchers.add(voucher) basket.vouchers.add(voucher)
Applicator().apply(basket, request.user, request) Applicator().apply(basket, request.user, request)
logger.info('Applied Voucher [%s] to basket [%s].', voucher.code, basket.id) logger.info('Applied Voucher [%s] to basket [%s].', voucher.code, basket.id)
affiliate_id = request.COOKIES.get(settings.AFFILIATE_COOKIE_KEY) affiliate_id = request.COOKIES.get(settings.AFFILIATE_COOKIE_KEY)
if affiliate_id: if affiliate_id:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment