Commit 76c4259c by Calen Pennington

Merge pull request #399 from MITx/feature/fix_replication_side_effects

Feature/fix replication side effects
parents add05059 c4d89cd5
...@@ -257,8 +257,11 @@ def add_user_to_default_group(user, group): ...@@ -257,8 +257,11 @@ def add_user_to_default_group(user, group):
########################## REPLICATION SIGNALS ################################# ########################## REPLICATION SIGNALS #################################
@receiver(post_save, sender=User) @receiver(post_save, sender=User)
def replicate_user_save(sender, **kwargs): def replicate_user_save(sender, **kwargs):
user_obj = kwargs['instance'] user_obj = kwargs['instance']
return replicate_model(User.save, user_obj, user_obj.id) if not should_replicate(user_obj):
return
for course_db_name in db_names_to_replicate_to(user_obj.id):
replicate_user(user_obj, course_db_name)
@receiver(post_save, sender=CourseEnrollment) @receiver(post_save, sender=CourseEnrollment)
def replicate_enrollment_save(sender, **kwargs): def replicate_enrollment_save(sender, **kwargs):
...@@ -287,8 +290,8 @@ def replicate_enrollment_save(sender, **kwargs): ...@@ -287,8 +290,8 @@ def replicate_enrollment_save(sender, **kwargs):
@receiver(post_delete, sender=CourseEnrollment) @receiver(post_delete, sender=CourseEnrollment)
def replicate_enrollment_delete(sender, **kwargs): def replicate_enrollment_delete(sender, **kwargs):
enrollment_obj = kwargs['instance'] enrollment_obj = kwargs['instance']
return replicate_model(CourseEnrollment.delete, enrollment_obj, enrollment_obj.user_id) return replicate_model(CourseEnrollment.delete, enrollment_obj, enrollment_obj.user_id)
@receiver(post_save, sender=UserProfile) @receiver(post_save, sender=UserProfile)
def replicate_userprofile_save(sender, **kwargs): def replicate_userprofile_save(sender, **kwargs):
...@@ -311,23 +314,20 @@ def replicate_user(portal_user, course_db_name): ...@@ -311,23 +314,20 @@ def replicate_user(portal_user, course_db_name):
overridden. overridden.
""" """
try: try:
# If the user exists in the Course DB, update the appropriate fields and
# save it back out to the Course DB.
course_user = User.objects.using(course_db_name).get(id=portal_user.id) course_user = User.objects.using(course_db_name).get(id=portal_user.id)
for field in USER_FIELDS_TO_COPY:
setattr(course_user, field, getattr(portal_user, field))
mark_handled(course_user)
log.debug("User {0} found in Course DB, replicating fields to {1}" log.debug("User {0} found in Course DB, replicating fields to {1}"
.format(course_user, course_db_name)) .format(course_user, course_db_name))
course_user.save(using=course_db_name) # Just being explicit.
except User.DoesNotExist: except User.DoesNotExist:
# Otherwise, just make a straight copy to the Course DB.
mark_handled(portal_user)
log.debug("User {0} not found in Course DB, creating copy in {1}" log.debug("User {0} not found in Course DB, creating copy in {1}"
.format(portal_user, course_db_name)) .format(portal_user, course_db_name))
portal_user.save(using=course_db_name) course_user = User()
for field in USER_FIELDS_TO_COPY:
setattr(course_user, field, getattr(portal_user, field))
mark_handled(course_user)
course_user.save(using=course_db_name)
unmark(course_user)
def replicate_model(model_method, instance, user_id): def replicate_model(model_method, instance, user_id):
""" """
...@@ -337,13 +337,14 @@ def replicate_model(model_method, instance, user_id): ...@@ -337,13 +337,14 @@ def replicate_model(model_method, instance, user_id):
if not should_replicate(instance): if not should_replicate(instance):
return return
mark_handled(instance)
course_db_names = db_names_to_replicate_to(user_id) course_db_names = db_names_to_replicate_to(user_id)
log.debug("Replicating {0} for user {1} to DBs: {2}" log.debug("Replicating {0} for user {1} to DBs: {2}"
.format(model_method, user_id, course_db_names)) .format(model_method, user_id, course_db_names))
mark_handled(instance)
for db_name in course_db_names: for db_name in course_db_names:
model_method(instance, using=db_name) model_method(instance, using=db_name)
unmark(instance)
######### Replication Helpers ######### ######### Replication Helpers #########
...@@ -371,7 +372,7 @@ def db_names_to_replicate_to(user_id): ...@@ -371,7 +372,7 @@ def db_names_to_replicate_to(user_id):
def marked_handled(instance): def marked_handled(instance):
"""Have we marked this instance as being handled to avoid infinite loops """Have we marked this instance as being handled to avoid infinite loops
caused by saving models in post_save hooks for the same models?""" caused by saving models in post_save hooks for the same models?"""
return hasattr(instance, '_do_not_copy_to_course_db') return hasattr(instance, '_do_not_copy_to_course_db') and instance._do_not_copy_to_course_db
def mark_handled(instance): def mark_handled(instance):
"""You have to mark your instance with this function or else we'll go into """You have to mark your instance with this function or else we'll go into
...@@ -384,6 +385,11 @@ def mark_handled(instance): ...@@ -384,6 +385,11 @@ def mark_handled(instance):
""" """
instance._do_not_copy_to_course_db = True instance._do_not_copy_to_course_db = True
def unmark(instance):
"""If we don't unmark a model after we do replication, then consecutive
save() calls won't be properly replicated."""
instance._do_not_copy_to_course_db = False
def should_replicate(instance): def should_replicate(instance):
"""Should this instance be replicated? We need to be a Portal server and """Should this instance be replicated? We need to be a Portal server and
the instance has to not have been marked_handled.""" the instance has to not have been marked_handled."""
...@@ -398,9 +404,3 @@ def should_replicate(instance): ...@@ -398,9 +404,3 @@ def should_replicate(instance):
return False return False
return True return True
...@@ -4,6 +4,7 @@ when you run "manage.py test". ...@@ -4,6 +4,7 @@ when you run "manage.py test".
Replace this with more appropriate tests for your application. Replace this with more appropriate tests for your application.
""" """
import logging
from datetime import datetime from datetime import datetime
from django.test import TestCase from django.test import TestCase
...@@ -13,6 +14,8 @@ from .models import User, UserProfile, CourseEnrollment, replicate_user, USER_FI ...@@ -13,6 +14,8 @@ from .models import User, UserProfile, CourseEnrollment, replicate_user, USER_FI
COURSE_1 = 'edX/toy/2012_Fall' COURSE_1 = 'edX/toy/2012_Fall'
COURSE_2 = 'edx/full/6.002_Spring_2012' COURSE_2 = 'edx/full/6.002_Spring_2012'
log = logging.getLogger(__name__)
class ReplicationTest(TestCase): class ReplicationTest(TestCase):
multi_db = True multi_db = True
...@@ -47,23 +50,18 @@ class ReplicationTest(TestCase): ...@@ -47,23 +50,18 @@ class ReplicationTest(TestCase):
field, portal_user, course_user field, portal_user, course_user
)) ))
if hasattr(portal_user, 'seen_response_count'):
# Since it's the first copy over of User data, we should have all of it
self.assertEqual(portal_user.seen_response_count,
course_user.seen_response_count)
# But if we replicate again, the user already exists in the Course DB,
# so it shouldn't update the seen_response_count (which is Askbot
# controlled).
# This hasattr lameness is here because we don't want this test to be # This hasattr lameness is here because we don't want this test to be
# triggered when we're being run by CMS tests (Askbot doesn't exist # triggered when we're being run by CMS tests (Askbot doesn't exist
# there, so the test will fail). # there, so the test will fail).
#
# seen_response_count isn't a field we care about, so it shouldn't have
# been copied over.
if hasattr(portal_user, 'seen_response_count'): if hasattr(portal_user, 'seen_response_count'):
portal_user.seen_response_count = 20 portal_user.seen_response_count = 20
replicate_user(portal_user, COURSE_1) replicate_user(portal_user, COURSE_1)
course_user = User.objects.using(COURSE_1).get(id=portal_user.id) course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.seen_response_count, 20) self.assertEqual(portal_user.seen_response_count, 20)
self.assertEqual(course_user.seen_response_count, 10) self.assertEqual(course_user.seen_response_count, 0)
# Another replication should work for an email change however, since # Another replication should work for an email change however, since
# it's a field we care about. # it's a field we care about.
...@@ -123,6 +121,25 @@ class ReplicationTest(TestCase): ...@@ -123,6 +121,25 @@ class ReplicationTest(TestCase):
UserProfile.objects.using(COURSE_2).get, UserProfile.objects.using(COURSE_2).get,
id=portal_user_profile.id) id=portal_user_profile.id)
log.debug("Make sure our seen_response_count is not replicated.")
if hasattr(portal_user, 'seen_response_count'):
portal_user.seen_response_count = 200
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.seen_response_count, 200)
self.assertEqual(course_user.seen_response_count, 0)
portal_user.save()
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.seen_response_count, 200)
self.assertEqual(course_user.seen_response_count, 0)
portal_user.email = 'jim@edx.org'
portal_user.save()
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.email, 'jim@edx.org')
self.assertEqual(course_user.email, 'jim@edx.org')
def test_enrollment_for_user_info_after_enrollment(self): def test_enrollment_for_user_info_after_enrollment(self):
"""Test the effect of modifying User data after you've enrolled.""" """Test the effect of modifying User data after you've enrolled."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment