Commit 96050ae1 by McKenzie Welter Committed by McKenzie Welter

check for entitlement products in addition to course run products

parent 4cfb8c90
...@@ -203,6 +203,24 @@ class DiscoveryMockMixin(object): ...@@ -203,6 +203,24 @@ class DiscoveryMockMixin(object):
] ]
) )
def mock_catalog_query_contains_endpoint(self, course_run_ids, course_uuids, absent_ids, query, discovery_api_url):
query_contains_info = {str(identifier): True for identifier in course_run_ids + course_uuids}
for identifier in absent_ids:
query_contains_info[str(identifier)] = False
query_contains_info_json = json.dumps(query_contains_info)
url = '{base}catalog/query_contains/?course_run_ids={run_ids}&course_uuids={uuids}&query={query}'.format(
base=discovery_api_url,
run_ids=",".join(course_run_id for course_run_id in course_run_ids),
uuids=",".join(str(course_uuid) for course_uuid in course_uuids),
query=query
)
httpretty.register_uri(
httpretty.GET, url,
body=query_contains_info_json,
content_type='application/json'
)
return url
def mock_catalog_contains_endpoint( def mock_catalog_contains_endpoint(
self, discovery_api_url, catalog_id=1, course_run_ids=None self, discovery_api_url, catalog_id=1, course_run_ids=None
): ):
...@@ -219,7 +237,6 @@ class DiscoveryMockMixin(object): ...@@ -219,7 +237,6 @@ class DiscoveryMockMixin(object):
catalog_contains_uri = '{}contains/?course_run_id={}'.format( catalog_contains_uri = '{}contains/?course_run_id={}'.format(
self.build_discovery_catalogs_url(discovery_api_url, catalog_id), ','.join(course_run_ids) self.build_discovery_catalogs_url(discovery_api_url, catalog_id), ','.join(course_run_ids)
) )
httpretty.register_uri( httpretty.register_uri(
method=httpretty.GET, method=httpretty.GET,
uri=catalog_contains_uri, uri=catalog_contains_uri,
......
from __future__ import unicode_literals from __future__ import unicode_literals
import logging import logging
from uuid import uuid4
from oscar.core.loading import get_model from oscar.core.loading import get_model
from oscar.core.utils import slugify from oscar.core.utils import slugify
...@@ -74,6 +75,16 @@ class DiscoveryTestMixin(object): ...@@ -74,6 +75,16 @@ class DiscoveryTestMixin(object):
seat = course.create_or_update_seat(seat_type, id_verification, price, partner) seat = course.create_or_update_seat(seat_type, id_verification, price, partner)
return course, seat return course, seat
def create_entitlement_product(self, course_uuid=None, certificate_type='verified'):
entitlement = factories.ProductFactory(
product_class=self.entitlement_product_class, stockrecords__partner=PartnerFactory(),
stockrecords__price_currency='USD'
)
entitlement.attr.UUID = course_uuid if course_uuid else uuid4()
entitlement.attr.certificate_type = certificate_type
entitlement.save()
return entitlement
def _create_product_class(self, class_name, slug, attributes): def _create_product_class(self, class_name, slug, attributes):
""" Helper method for creating product classes. """ Helper method for creating product classes.
......
from __future__ import unicode_literals from __future__ import unicode_literals
import logging
import re import re
from django.conf import settings from django.conf import settings
...@@ -21,6 +22,8 @@ from ecommerce.core.utils import get_cache_key, log_message_and_raise_validation ...@@ -21,6 +22,8 @@ from ecommerce.core.utils import get_cache_key, log_message_and_raise_validation
OFFER_PRIORITY_ENTERPRISE = 10 OFFER_PRIORITY_ENTERPRISE = 10
OFFER_PRIORITY_VOUCHER = 20 OFFER_PRIORITY_VOUCHER = 20
logger = logging.getLogger(__name__)
class Benefit(AbstractBenefit): class Benefit(AbstractBenefit):
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
...@@ -43,6 +46,77 @@ class Benefit(AbstractBenefit): ...@@ -43,6 +46,77 @@ class Benefit(AbstractBenefit):
if self.value > 100: if self.value > 100:
log_message_and_raise_validation_error('Percentage discount cannot be greater than 100') log_message_and_raise_validation_error('Percentage discount cannot be greater than 100')
def _filter_for_paid_course_products(self, lines, applicable_range):
"""" Filters out products that aren't seats or entitlements or that don't have a paid certificate type. """
return [
line for line in lines
if line.product.is_seat_product or line.product.is_course_entitlement_product and
hasattr(line.product.attr, 'certificate_type') and
line.product.attr.certificate_type.lower() in applicable_range.course_seat_types
]
def _identify_uncached_product_identifiers(self, lines, domain, partner_code, query):
"""
Checks the cache to see if each line is in the catalog range specified by the given query
and tracks identifiers for which discovery service data is still needed.
"""
course_run_ids = []
course_uuids = []
applicable_lines = lines
for line in applicable_lines:
product = line.product
cache_key = get_cache_key(
site_domain=domain,
partner_code=partner_code,
resource='catalog_query.contains',
course_id=product.course_id if product.is_seat_product else product.attr.UUID,
query=query
)
response = cache.get(cache_key)
if response is False:
applicable_lines.remove(line)
elif response is None:
if product.is_seat_product:
course_run_ids.append({'id': product.course.id, 'cache_key': cache_key, 'line': line})
else:
course_uuids.append({'id': product.attr.UUID, 'cache_key': cache_key, 'line': line})
return course_run_ids, course_uuids, applicable_lines
def get_applicable_lines(self, offer, basket, range=None): # pylint: disable=redefined-builtin
applicable_range = range if range else self.range
if applicable_range and applicable_range.catalog_query is not None:
query = applicable_range.catalog_query
applicable_lines = self._filter_for_paid_course_products(basket.all_lines(), applicable_range)
site = basket.site
partner_code = site.siteconfiguration.partner.short_code
course_run_ids, course_uuids, applicable_lines = self._identify_uncached_product_identifiers(
applicable_lines, site.domain, partner_code, query
)
if course_run_ids or course_uuids:
# Hit Discovery Service to determine if remaining courses and runs are in the range.
try:
response = site.siteconfiguration.discovery_api_client.catalog.query_contains.get(
course_run_ids=','.join([metadata['id'] for metadata in course_run_ids]),
course_uuids=','.join([metadata['id'] for metadata in course_uuids]),
query=query,
partner=partner_code
)
except Exception as err: # pylint: disable=bare-except
logger.warning(
'%s raised while attempting to contact Discovery Service for offer catalog_range data.', err
)
raise Exception('Failed to contact Discovery Service to retrieve offer catalog_range data.')
# Cache range-state individually for each course or run identifier and remove lines not in the range.
for metadata in course_run_ids + course_uuids:
in_range = response[str(metadata['id'])]
cache.set(metadata['cache_key'], in_range, settings.COURSES_API_CACHE_TIMEOUT)
if not in_range:
applicable_lines.remove(metadata['line'])
return applicable_lines
else:
return super(Benefit, self).get_applicable_lines(offer, basket, range=range) # pylint: disable=bad-super-call
class ConditionalOffer(AbstractConditionalOffer): class ConditionalOffer(AbstractConditionalOffer):
UPDATABLE_OFFER_FIELDS = ['email_domains', 'max_uses'] UPDATABLE_OFFER_FIELDS = ['email_domains', 'max_uses']
...@@ -157,6 +231,10 @@ class ConditionalOffer(AbstractConditionalOffer): ...@@ -157,6 +231,10 @@ class ConditionalOffer(AbstractConditionalOffer):
""" """
if not self.is_email_valid(basket.owner.email): if not self.is_email_valid(basket.owner.email):
return False return False
if self.benefit.range and self.benefit.range.catalog_query:
# The condition is only satisfied if all basket lines are in the offer range
return len(self.benefit.get_applicable_lines(self, basket)) == basket.all_lines().count()
return super(ConditionalOffer, self).is_condition_satisfied(basket) # pylint: disable=bad-super-call return super(ConditionalOffer, self).is_condition_satisfied(basket) # pylint: disable=bad-super-call
...@@ -231,33 +309,6 @@ class Range(AbstractRange): ...@@ -231,33 +309,6 @@ class Range(AbstractRange):
if self.course_seat_types: if self.course_seat_types:
validate_credit_seat_type(self.course_seat_types) validate_credit_seat_type(self.course_seat_types)
def run_catalog_query(self, product):
"""
Retrieve the results from running the query contained in catalog_query field.
"""
request = get_current_request()
partner_code = request.site.siteconfiguration.partner.short_code
cache_key = get_cache_key(
site_domain=request.site.domain,
partner_code=partner_code,
resource='course_runs.contains',
course_id=product.course_id,
query=self.catalog_query
)
response = cache.get(cache_key)
if not response: # pragma: no cover
try:
response = request.site.siteconfiguration.discovery_api_client.course_runs.contains.get(
query=self.catalog_query,
course_run_ids=product.course_id,
partner=partner_code
)
cache.set(cache_key, response, settings.COURSES_API_CACHE_TIMEOUT)
except: # pylint: disable=bare-except
raise Exception('Could not contact Discovery Service.')
return response
def catalog_contains_product(self, product): def catalog_contains_product(self, product):
""" """
Retrieve the results from using the catalog contains endpoint for Retrieve the results from using the catalog contains endpoint for
...@@ -299,13 +350,6 @@ class Range(AbstractRange): ...@@ -299,13 +350,6 @@ class Range(AbstractRange):
# therefor an OR is used to check for both possibilities. # therefor an OR is used to check for both possibilities.
return ((response['courses'][product.course_id]) or return ((response['courses'][product.course_id]) or
super(Range, self).contains_product(product)) # pylint: disable=bad-super-call super(Range, self).contains_product(product)) # pylint: disable=bad-super-call
elif self.catalog_query and self.course_seat_types:
if product.attr.certificate_type.lower() in self.course_seat_types: # pylint: disable=unsupported-membership-test
response = self.run_catalog_query(product)
# Range can have a catalog query and 'regular' products in it,
# therefor an OR is used to check for both possibilities.
return ((response['course_runs'][product.course_id]) or
super(Range, self).contains_product(product)) # pylint: disable=bad-super-call
elif self.catalog: elif self.catalog:
return ( return (
product.id in self.catalog.stock_records.values_list('product', flat=True) or product.id in self.catalog.stock_records.values_list('product', flat=True) or
......
...@@ -3,16 +3,12 @@ from __future__ import unicode_literals ...@@ -3,16 +3,12 @@ from __future__ import unicode_literals
import ddt import ddt
import httpretty import httpretty
import mock
from django.core.cache import cache
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.test import RequestFactory
from oscar.core.loading import get_model from oscar.core.loading import get_model
from oscar.test import factories from oscar.test import factories
from requests.exceptions import ConnectionError, Timeout from requests.exceptions import ConnectionError, Timeout
from slumber.exceptions import SlumberBaseException from slumber.exceptions import SlumberBaseException
from ecommerce.core.utils import get_cache_key
from ecommerce.coupons.tests.mixins import CouponMixin, DiscoveryMockMixin from ecommerce.coupons.tests.mixins import CouponMixin, DiscoveryMockMixin
from ecommerce.extensions.catalogue.tests.mixins import DiscoveryTestMixin from ecommerce.extensions.catalogue.tests.mixins import DiscoveryTestMixin
from ecommerce.tests.testcases import TestCase from ecommerce.tests.testcases import TestCase
...@@ -92,62 +88,6 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase): ...@@ -92,62 +88,6 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase):
self.range.save() self.range.save()
self.assertEqual(self.range.catalog_query, large_query) self.assertEqual(self.range.catalog_query, large_query)
@mock.patch('ecommerce.core.url_utils.get_current_request', mock.Mock(return_value=None))
def test_run_catalog_query_no_request(self):
"""
run_course_query() should return status 400 response when no request is present.
"""
with self.assertRaises(Exception):
self.range.run_catalog_query(self.product)
def test_run_catalog_query(self):
"""
run_course_query() should return True for included course run ID's.
"""
course, seat = self.create_course_and_seat()
self.mock_access_token_response()
self.mock_course_runs_contains_endpoint(
query='key:*', course_run_ids=[course.id], discovery_api_url=self.site_configuration.discovery_api_url
)
request = RequestFactory()
request.site = self.site
self.range.catalog_query = 'key:*'
partner_code = request.site.siteconfiguration.partner.short_code
cache_key = get_cache_key(
site_domain=request.site.domain,
partner_code=partner_code,
resource='course_runs.contains',
course_id=seat.course_id,
query=self.range.catalog_query
)
cached_response = cache.get(cache_key)
self.assertIsNone(cached_response)
with mock.patch('ecommerce.core.url_utils.get_current_request', mock.Mock(return_value=request)):
response = self.range.run_catalog_query(seat)
self.assertTrue(response['course_runs'][course.id])
cached_response = cache.get(cache_key)
self.assertEqual(response, cached_response)
def test_query_range_contains_product(self):
"""
contains_product() should return the correct boolean if a product is in it's range.
"""
course, seat = self.create_course_and_seat()
self.mock_access_token_response()
self.mock_course_runs_contains_endpoint(
query='key:*', course_run_ids=[course.id], discovery_api_url=self.site_configuration.discovery_api_url
)
false_response = self.range.contains_product(seat)
self.assertFalse(false_response)
self.range.catalog_query = 'key:*'
self.range.course_seat_types = 'verified'
response = self.range.contains_product(seat)
self.assertTrue(response)
def test_course_catalog_query_range_contains_product(self): def test_course_catalog_query_range_contains_product(self):
""" """
Verify that the method "contains_product" returns True (boolean) if a Verify that the method "contains_product" returns True (boolean) if a
...@@ -156,9 +96,9 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase): ...@@ -156,9 +96,9 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase):
catalog_query = 'key:*' catalog_query = 'key:*'
course, seat = self.create_course_and_seat() course, seat = self.create_course_and_seat()
self.mock_access_token_response() self.mock_access_token_response()
self.mock_course_runs_contains_endpoint( self.mock_catalog_query_contains_endpoint(
query=catalog_query, course_run_ids=[course.id], query=catalog_query, course_run_ids=[course.id], course_uuids=[], absent_ids=[],
discovery_api_url=self.site_configuration.discovery_api_url discovery_api_url=self.site_configuration.discovery_api_url,
) )
false_response = self.range.contains_product(seat) false_response = self.range.contains_product(seat)
...@@ -356,7 +296,8 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase): ...@@ -356,7 +296,8 @@ class RangeTests(CouponMixin, DiscoveryTestMixin, DiscoveryMockMixin, TestCase):
@ddt.ddt @ddt.ddt
class ConditionalOfferTests(TestCase): @httpretty.activate
class ConditionalOfferTests(DiscoveryTestMixin, DiscoveryMockMixin, TestCase):
"""Tests for custom ConditionalOffer model.""" """Tests for custom ConditionalOffer model."""
def setUp(self): def setUp(self):
super(ConditionalOfferTests, self).setUp() super(ConditionalOfferTests, self).setUp()
...@@ -396,6 +337,36 @@ class ConditionalOfferTests(TestCase): ...@@ -396,6 +337,36 @@ class ConditionalOfferTests(TestCase):
basket = self.create_basket(email='test@invalid.domain') basket = self.create_basket(email='test@invalid.domain')
self.assertFalse(self.offer.is_condition_satisfied(basket)) self.assertFalse(self.offer.is_condition_satisfied(basket))
def test_is_range_condition_satisfied(self):
"""
Verify that a basket satisfies a condition only when all of its products are in its range's catalog queryset.
"""
valid_user_email = 'valid@{domain}'.format(domain=self.valid_sub_domain)
basket = factories.BasketFactory(site=self.site, owner=factories.UserFactory(email=valid_user_email))
product = self.create_entitlement_product()
another_product = self.create_entitlement_product()
_range = factories.RangeFactory()
_range.course_seat_types = ','.join(Range.ALLOWED_SEAT_TYPES)
_range.catalog_query = 'uuid:{course_uuid}'.format(course_uuid=product.attr.UUID)
benefit = factories.BenefitFactory(range=_range)
offer = factories.ConditionalOfferFactory(benefit=benefit)
self.mock_access_token_response()
self.mock_catalog_query_contains_endpoint(
course_run_ids=[], course_uuids=[product.attr.UUID], absent_ids=[another_product.attr.UUID],
query=benefit.range.catalog_query, discovery_api_url=self.site_configuration.discovery_api_url
)
basket.add_product(product)
self.assertTrue(offer.is_condition_satisfied(basket))
basket.add_product(another_product)
self.assertFalse(offer.is_condition_satisfied(basket))
# Verify that API return values are cached
httpretty.disable()
self.assertFalse(offer.is_condition_satisfied(basket))
def test_is_email_valid(self): def test_is_email_valid(self):
"""Verify method returns True for valid emails.""" """Verify method returns True for valid emails."""
invalid_email = 'invalid@email.fake' invalid_email = 'invalid@email.fake'
...@@ -460,7 +431,18 @@ class ConditionalOfferTests(TestCase): ...@@ -460,7 +431,18 @@ class ConditionalOfferTests(TestCase):
self.assertEqual(offer.site, None) self.assertEqual(offer.site, None)
class BenefitTests(TestCase): class BenefitTests(DiscoveryTestMixin, DiscoveryMockMixin, TestCase):
def setUp(self):
super(BenefitTests, self).setUp()
_range = factories.RangeFactory(
course_seat_types=','.join(Range.ALLOWED_SEAT_TYPES[1:]),
catalog_query='uuid:*'
)
self.benefit = factories.BenefitFactory(range=_range)
self.offer = factories.ConditionalOfferFactory(benefit=self.benefit)
self.user = factories.UserFactory()
def test_range(self): def test_range(self):
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
factories.BenefitFactory(range=None) factories.BenefitFactory(range=None)
...@@ -471,3 +453,33 @@ class BenefitTests(TestCase): ...@@ -471,3 +453,33 @@ class BenefitTests(TestCase):
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
factories.BenefitFactory(value=-10) factories.BenefitFactory(value=-10)
@httpretty.activate
def test_get_applicable_lines(self):
""" Assert that basket lines matching the range's discovery query are selected. """
basket = factories.BasketFactory(site=self.site, owner=self.user)
entitlement_product = self.create_entitlement_product()
course, seat = self.create_course_and_seat()
no_certificate_product = factories.ProductFactory(stockrecords__price_currency='USD')
basket.add_product(entitlement_product)
basket.add_product(seat)
applicable_lines = list(basket.all_lines())
basket.add_product(no_certificate_product)
self.mock_access_token_response()
# Verify that the method raises an exception when it fails to reach the Discovery Service
with self.assertRaises(Exception):
self.benefit.get_applicable_lines(self.offer, basket)
self.mock_catalog_query_contains_endpoint(
course_run_ids=[], course_uuids=[entitlement_product.attr.UUID, course.id], absent_ids=[],
query=self.benefit.range.catalog_query, discovery_api_url=self.site_configuration.discovery_api_url
)
self.assertEqual(self.benefit.get_applicable_lines(self.offer, basket), applicable_lines)
# Verify that the API return value is cached
httpretty.disable()
self.assertEqual(self.benefit.get_applicable_lines(self.offer, basket), applicable_lines)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment