Commit e8e5f5a5 by Vedran Karacic

prepare_basket adjustments

parent c9fb96a7
...@@ -45,6 +45,11 @@ class Basket(AbstractBasket): ...@@ -45,6 +45,11 @@ class Basket(AbstractBasket):
return basket return basket
def clear_vouchers(self):
"""Remove all vouchers applied to the basket."""
for v in self.vouchers.all():
self.vouchers.remove(v)
def __unicode__(self): def __unicode__(self):
return _(u"{id} - {status} basket (owner: {owner}, lines: {num_lines})").format( return _(u"{id} - {status} basket (owner: {owner}, lines: {num_lines})").format(
id=self.id, id=self.id,
......
...@@ -7,6 +7,7 @@ from ecommerce.core.constants import ENROLLMENT_CODE_PRODUCT_CLASS_NAME, ENROLLM ...@@ -7,6 +7,7 @@ from ecommerce.core.constants import ENROLLMENT_CODE_PRODUCT_CLASS_NAME, ENROLLM
from ecommerce.core.tests import toggle_switch from ecommerce.core.tests import toggle_switch
from ecommerce.courses.tests.factories import CourseFactory from ecommerce.courses.tests.factories import CourseFactory
from ecommerce.extensions.basket.utils import prepare_basket from ecommerce.extensions.basket.utils import prepare_basket
from ecommerce.extensions.catalogue.tests.mixins import CourseCatalogTestMixin
from ecommerce.extensions.partner.models import StockRecord from ecommerce.extensions.partner.models import StockRecord
from ecommerce.extensions.test.factories import prepare_voucher from ecommerce.extensions.test.factories import prepare_voucher
from ecommerce.referrals.models import Referral from ecommerce.referrals.models import Referral
...@@ -19,7 +20,7 @@ Product = get_model('catalogue', 'Product') ...@@ -19,7 +20,7 @@ Product = get_model('catalogue', 'Product')
@ddt.ddt @ddt.ddt
class BasketUtilsTests(TestCase): class BasketUtilsTests(CourseCatalogTestMixin, TestCase):
""" Tests for basket utility functions. """ """ Tests for basket utility functions. """
def setUp(self): def setUp(self):
...@@ -51,20 +52,22 @@ class BasketUtilsTests(TestCase): ...@@ -51,20 +52,22 @@ class BasketUtilsTests(TestCase):
self.assertEqual(basket.total_excl_tax, 90.00) self.assertEqual(basket.total_excl_tax, 90.00)
def test_prepare_basket_enrollment_with_voucher(self): def test_prepare_basket_enrollment_with_voucher(self):
"""Verify the basket does not contain a voucher if enrollment code is added to it."""
course = CourseFactory() course = CourseFactory()
toggle_switch(ENROLLMENT_CODE_SWITCH, True) toggle_switch(ENROLLMENT_CODE_SWITCH, True)
course.create_or_update_seat('verified', False, 10, self.partner, create_enrollment_code=True) course.create_or_update_seat('verified', False, 10, self.partner, create_enrollment_code=True)
enrollment_code = Product.objects.get(product_class__name=ENROLLMENT_CODE_PRODUCT_CLASS_NAME) enrollment_code = Product.objects.get(product_class__name=ENROLLMENT_CODE_PRODUCT_CLASS_NAME)
# Prepare a product with price of 100 and a voucher with 10% discount for that product. voucher, product = prepare_voucher()
product = ProductFactory(stockrecords__price_excl_tax=100)
new_range = RangeFactory(products=[product, ])
voucher, product = prepare_voucher(_range=new_range, benefit_value=10)
basket = prepare_basket(self.request, product, voucher) basket = prepare_basket(self.request, product, voucher)
self.assertIsNotNone(basket) self.assertIsNotNone(basket)
self.assertEqual(basket.vouchers.count(), 1) self.assertEqual(basket.all_lines()[0].product, product)
basket = prepare_basket(self.request, enrollment_code) self.assertTrue(basket.contains_a_voucher)
basket = prepare_basket(self.request, enrollment_code, voucher)
self.assertIsNotNone(basket) self.assertIsNotNone(basket)
self.assertEqual(basket.vouchers.count(), 0) self.assertEqual(basket.all_lines()[0].product, enrollment_code)
self.assertFalse(basket.contains_a_voucher)
def test_multiple_vouchers(self): def test_multiple_vouchers(self):
""" Verify only the last entered voucher is contained in the basket. """ """ Verify only the last entered voucher is contained in the basket. """
......
...@@ -34,13 +34,13 @@ def prepare_basket(request, product, voucher=None): ...@@ -34,13 +34,13 @@ def prepare_basket(request, product, voucher=None):
basket = Basket.get_basket(request.user, request.site) basket = Basket.get_basket(request.user, request.site)
basket.flush() basket.flush()
basket.add_product(product, 1) basket.add_product(product, 1)
if voucher or product.get_product_class().name == ENROLLMENT_CODE_PRODUCT_CLASS_NAME: if product.get_product_class().name == ENROLLMENT_CODE_PRODUCT_CLASS_NAME:
for v in basket.vouchers.all(): basket.clear_vouchers()
basket.vouchers.remove(v) elif voucher:
if voucher: basket.clear_vouchers()
basket.vouchers.add(voucher) basket.vouchers.add(voucher)
Applicator().apply(basket, request.user, request) Applicator().apply(basket, request.user, request)
logger.info('Applied Voucher [%s] to basket [%s].', voucher.code, basket.id) logger.info('Applied Voucher [%s] to basket [%s].', voucher.code, basket.id)
affiliate_id = request.COOKIES.get(settings.AFFILIATE_COOKIE_KEY) affiliate_id = request.COOKIES.get(settings.AFFILIATE_COOKIE_KEY)
if affiliate_id: if affiliate_id:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment