Commit 992d83e6 by Clinton Blackburn

Merge pull request #224 from edx/migration-fix

Updated Course Migration and Publishing
parents c605f62a d0841be7
...@@ -5,9 +5,7 @@ from django.db import models, transaction ...@@ -5,9 +5,7 @@ from django.db import models, transaction
from django.utils.text import slugify from django.utils.text import slugify
from oscar.core.loading import get_model from oscar.core.loading import get_model
from simple_history.models import HistoricalRecords from simple_history.models import HistoricalRecords
import waffle
from ecommerce.courses.exceptions import PublishFailed
from ecommerce.courses.publishers import LMSPublisher from ecommerce.courses.publishers import LMSPublisher
from ecommerce.extensions.catalogue.utils import generate_sku from ecommerce.extensions.catalogue.utils import generate_sku
...@@ -48,20 +46,11 @@ class Course(models.Model): ...@@ -48,20 +46,11 @@ class Course(models.Model):
else: else:
logger.debug('Parent seat [%d] already exists for [%s].', parent.id, self.id) logger.debug('Parent seat [%d] already exists for [%s].', parent.id, self.id)
# pylint: disable=arguments-differ
@transaction.atomic @transaction.atomic
def save(self, force_insert=False, force_update=False, using=None, update_fields=None, publish=True): def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
super(Course, self).save(force_insert, force_update, using, update_fields) super(Course, self).save(force_insert, force_update, using, update_fields)
self._create_parent_seat() self._create_parent_seat()
if publish and waffle.switch_is_active('publish_course_modes_to_lms'):
if not self.publish_to_lms():
# Raise an exception to force a rollback
raise PublishFailed('Failed to publish {}'.format(self.id))
else:
logger.debug('Course mode publishing is not enabled. Commerce changes will not be published!')
def publish_to_lms(self): def publish_to_lms(self):
""" Publish Course and Products to LMS. """ """ Publish Course and Products to LMS. """
return LMSPublisher().publish(self) return LMSPublisher().publish(self)
...@@ -125,13 +114,22 @@ class Course(models.Model): ...@@ -125,13 +114,22 @@ class Course(models.Model):
certificate_type = certificate_type.lower() certificate_type = certificate_type.lower()
course_id = unicode(self.id) course_id = unicode(self.id)
slugs = []
slug = 'child-cs-{}-{}'.format(certificate_type, slugify(course_id)) slug = 'child-cs-{}-{}'.format(certificate_type, slugify(course_id))
# Note (CCB): Our previous method of slug generation did not account for ID verification. By using a list
# we can update these seats. This should be removed after the courses have been re-migrated.
if certificate_type == 'verified':
slugs.append(slug)
if id_verification_required: if id_verification_required:
slug += '-id-verified' slug += '-id-verified'
slugs.append(slug)
slugs = set(slugs)
try: try:
seat = Product.objects.get(slug=slug) seat = Product.objects.get(slug__in=slugs)
logger.info('Retrieved [%s] course seat child product for [%s] from database.', certificate_type, logger.info('Retrieved [%s] course seat child product for [%s] from database.', certificate_type,
course_id) course_id)
except Product.DoesNotExist: except Product.DoesNotExist:
......
...@@ -14,6 +14,12 @@ class LMSPublisher(object): ...@@ -14,6 +14,12 @@ class LMSPublisher(object):
return 'no-id-professional' return 'no-id-professional'
return seat.attr.certificate_type return seat.attr.certificate_type
def get_seat_expiration(self, seat):
if not seat.expires or 'professional' in seat.attr.certificate_type:
return None
return seat.expires.isoformat()
def serialize_seat_for_commerce_api(self, seat): def serialize_seat_for_commerce_api(self, seat):
""" Serializes a course seat product to a dict that can be further serialized to JSON. """ """ Serializes a course seat product to a dict that can be further serialized to JSON. """
stock_record = seat.stockrecords.first() stock_record = seat.stockrecords.first()
...@@ -22,7 +28,7 @@ class LMSPublisher(object): ...@@ -22,7 +28,7 @@ class LMSPublisher(object):
'currency': stock_record.price_currency, 'currency': stock_record.price_currency,
'price': int(stock_record.price_excl_tax), 'price': int(stock_record.price_excl_tax),
'sku': stock_record.partner_sku, 'sku': stock_record.partner_sku,
'expires': seat.expires.isoformat() if seat.expires else None, 'expires': self.get_seat_expiration(seat),
} }
def publish(self, course): def publish(self, course):
......
import ddt import ddt
from django.test import TestCase from django.test import TestCase
from django_dynamic_fixture import G, N from django_dynamic_fixture import G
import mock import mock
from oscar.core.loading import get_model from oscar.core.loading import get_model
from testfixtures import LogCapture
from waffle import Switch
from ecommerce.courses.models import Course from ecommerce.courses.models import Course
from ecommerce.courses.publishers import LMSPublisher from ecommerce.courses.publishers import LMSPublisher
...@@ -76,56 +74,6 @@ class CourseTests(CourseCatalogTestMixin, TestCase): ...@@ -76,56 +74,6 @@ class CourseTests(CourseCatalogTestMixin, TestCase):
course.publish_to_lms() course.publish_to_lms()
self.assertTrue(mock_publish.called) self.assertTrue(mock_publish.called)
def test_save_and_publish_to_lms(self):
""" Verify the save method calls publish_to_lms if the feature is enabled. """
switch, __ = Switch.objects.get_or_create(name='publish_course_modes_to_lms', active=False)
course = G(Course)
with mock.patch.object(Course, 'publish_to_lms') as mock_publish:
logger_name = 'ecommerce.courses.models'
with LogCapture(logger_name) as l:
course.save()
l.check(
(logger_name, 'DEBUG',
'Parent seat [{}] already exists for [{}].'.format(course.parent_seat_product.id, course.id)),
(logger_name, 'DEBUG',
'Course mode publishing is not enabled. Commerce changes will not be published!')
)
self.assertFalse(mock_publish.called)
# Reset the mock and activate the feature.
mock_publish.reset_mock()
switch.active = True
switch.save()
# With the feature active, the mock method should be called.
course.save()
self.assertTrue(mock_publish.called)
def test_save_with_publish_failure(self):
""" Verify that, if the publish operation fails, the model's changes are not saved to the database. """
orignal_name = 'A Most Awesome Course'
course = G(Course, name=orignal_name)
Switch.objects.get_or_create(name='publish_course_modes_to_lms', active=True)
# Mock an error in the publisher
with mock.patch.object(LMSPublisher, 'publish', return_value=False):
course.name = 'An Okay Course'
# Reload the course from the database
course = Course.objects.get(id=course.id)
self.assertEqual(course.name, orignal_name)
def test_save_without_publish(self):
""" Verify the Course is not published to LMS if the publish kwarg is set to False. """
Switch.objects.get_or_create(name='publish_course_modes_to_lms', active=False)
course = N(Course)
with mock.patch.object(LMSPublisher, 'publish') as mock_publish:
course.save(publish=False)
self.assertFalse(mock_publish.called)
def test_save_creates_parent_seat(self): def test_save_creates_parent_seat(self):
""" Verify the save method creates a parent seat if one does not exist. """ """ Verify the save method creates a parent seat if one does not exist. """
course = Course.objects.create(id='a/b/c', name='Test Course') course = Course.objects.create(id='a/b/c', name='Test Course')
......
...@@ -115,3 +115,25 @@ class LMSPublisherTests(CourseCatalogTestMixin, TestCase): ...@@ -115,3 +115,25 @@ class LMSPublisherTests(CourseCatalogTestMixin, TestCase):
expected['expires'] = expires.isoformat() expected['expires'] = expires.isoformat()
actual = self.publisher.serialize_seat_for_commerce_api(seat) actual = self.publisher.serialize_seat_for_commerce_api(seat)
self.assertDictEqual(actual, expected) self.assertDictEqual(actual, expected)
@ddt.unpack
@ddt.data(
(True, 'professional'),
(False, 'no-id-professional'),
)
def test_serialize_seat_for_commerce_api_with_professional(self, is_verified, expected_mode):
"""
Verify that (a) professional seats NEVER have an expiration date and (b) the name/mode is properly set for
no-id-professional seats.
"""
seat = self.course.create_or_update_seat('professional', is_verified, 500, expires=datetime.datetime.utcnow())
stock_record = seat.stockrecords.first()
actual = self.publisher.serialize_seat_for_commerce_api(seat)
expected = {
'name': expected_mode,
'currency': 'USD',
'price': int(stock_record.price_excl_tax),
'sku': stock_record.partner_sku,
'expires': None
}
self.assertDictEqual(actual, expected)
...@@ -7,6 +7,7 @@ from django.conf import settings ...@@ -7,6 +7,7 @@ from django.conf import settings
from django.core.management import BaseCommand from django.core.management import BaseCommand
from django.db import transaction from django.db import transaction
import requests import requests
import waffle
from ecommerce.courses.models import Course from ecommerce.courses.models import Course
...@@ -15,12 +16,7 @@ logger = logging.getLogger(__name__) ...@@ -15,12 +16,7 @@ logger = logging.getLogger(__name__)
class MigratedCourse(object): class MigratedCourse(object):
def __init__(self, course_id): def __init__(self, course_id):
# Avoid use of get_or_create to prevent publication to the self.course, _created = Course.objects.get_or_create(id=course_id)
# LMS when saving the newly instantiated Course.
try:
self.course = Course.objects.get(id=course_id)
except Course.DoesNotExist:
self.course = Course(id=course_id)
def load_from_lms(self, access_token): def load_from_lms(self, access_token):
""" """
...@@ -30,7 +26,7 @@ class MigratedCourse(object): ...@@ -30,7 +26,7 @@ class MigratedCourse(object):
""" """
name, modes = self._retrieve_data_from_lms(access_token) name, modes = self._retrieve_data_from_lms(access_token)
self.course.name = name self.course.name = name
self.course.save(publish=False) self.course.save()
self._get_products(modes) self._get_products(modes)
def _build_lms_url(self, path): def _build_lms_url(self, path):
...@@ -118,6 +114,7 @@ class Command(BaseCommand): ...@@ -118,6 +114,7 @@ class Command(BaseCommand):
course = migrated_course.course course = migrated_course.course
msg = 'Retrieved info for {0} ({1}):\n'.format(course.id, course.name) msg = 'Retrieved info for {0} ({1}):\n'.format(course.id, course.name)
msg += '\t(cert. type, verified?, price, SKU, slug, expires)\n'
for seat in course.seat_products: for seat in course.seat_products:
stock_record = seat.stockrecords.first() stock_record = seat.stockrecords.first()
...@@ -129,10 +126,14 @@ class Command(BaseCommand): ...@@ -129,10 +126,14 @@ class Command(BaseCommand):
logger.info(msg) logger.info(msg)
if options.get('commit', False): if options.get('commit', False):
logger.info('Course [%s] was saved to the database.', migrated_course.course.id) logger.info('Course [%s] was saved to the database.', course.id)
transaction.commit() if waffle.switch_is_active('publish_course_modes_to_lms'):
course.publish_to_lms()
else:
logger.info('Data was not published to LMS because the switch '
'[publish_course_modes_to_lms] is disabled.')
else: else:
logger.info('Course [%s] was NOT saved to the database.', migrated_course.course.id) logger.info('Course [%s] was NOT saved to the database.', course.id)
raise Exception('Forced rollback.') raise Exception('Forced rollback.')
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('Failed to migrate [%s]!', course_id) logger.exception('Failed to migrate [%s]!', course_id)
...@@ -177,10 +177,9 @@ class CommandTests(CourseMigrationTestMixin, TestCase): ...@@ -177,10 +177,9 @@ class CommandTests(CourseMigrationTestMixin, TestCase):
self._mock_lms_api() self._mock_lms_api()
with mock.patch.object(LMSPublisher, 'publish') as mock_publish: with mock.patch.object(LMSPublisher, 'publish') as mock_publish:
mock_publish.return_value = True
call_command('migrate_course', self.course_id, access_token=ACCESS_TOKEN, commit=True) call_command('migrate_course', self.course_id, access_token=ACCESS_TOKEN, commit=True)
# Verify that the migrated course was not published back to the LMS # Verify that the migrated course was published back to the LMS
self.assertFalse(mock_publish.called) self.assertTrue(mock_publish.called)
self.assert_course_migrated() self.assert_course_migrated()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment