Commit b3676cd7 by David Ormsbee

Add replication tests

parent 17207220
...@@ -255,6 +255,11 @@ def add_user_to_default_group(user, group): ...@@ -255,6 +255,11 @@ def add_user_to_default_group(user, group):
utg.save() utg.save()
########################## REPLICATION SIGNALS ################################# ########################## REPLICATION SIGNALS #################################
@receiver(post_save, sender=User)
def replicate_user_save(sender, **kwargs):
user_obj = kwargs['instance']
return replicate_model(User.save, user_obj.id, **kwargs)
@receiver(post_save, sender=CourseEnrollment) @receiver(post_save, sender=CourseEnrollment)
def replicate_enrollment_save(sender, **kwargs): def replicate_enrollment_save(sender, **kwargs):
"""This is called when a Student enrolls in a course. It has to do the """This is called when a Student enrolls in a course. It has to do the
...@@ -266,6 +271,9 @@ def replicate_enrollment_save(sender, **kwargs): ...@@ -266,6 +271,9 @@ def replicate_enrollment_save(sender, **kwargs):
2. Replicate the CourseEnrollment. 2. Replicate the CourseEnrollment.
3. Replicate the UserProfile. 3. Replicate the UserProfile.
""" """
if not is_portal():
return
enrollment_obj = kwargs['instance'] enrollment_obj = kwargs['instance']
replicate_user(enrollment_obj.user, enrollment_obj.course_id) replicate_user(enrollment_obj.user, enrollment_obj.course_id)
replicate_model(CourseEnrollment.save, enrollment_obj.user_id, **kwargs) replicate_model(CourseEnrollment.save, enrollment_obj.user_id, **kwargs)
...@@ -281,23 +289,26 @@ def replicate_userprofile_save(sender, **kwargs): ...@@ -281,23 +289,26 @@ def replicate_userprofile_save(sender, **kwargs):
"""We just updated the UserProfile (say an update to the name), so push that """We just updated the UserProfile (say an update to the name), so push that
change to all Course DBs that we're enrolled in.""" change to all Course DBs that we're enrolled in."""
user_profile_obj = kwargs['instance'] user_profile_obj = kwargs['instance']
return replicate_model(UserProfile.save, enrollment_obj.user_id, **kwargs) return replicate_model(UserProfile.save, user_profile_obj.user_id, **kwargs)
######### Replication functions ######### ######### Replication functions #########
USER_FIELDS_TO_COPY = ["id", "username", "first_name", "last_name", "email",
"password", "is_staff", "is_active", "is_superuser",
"last_login", "date_joined"]
def replicate_user(portal_user, course_db_name): def replicate_user(portal_user, course_db_name):
"""Replicate a User to the correct Course DB. This is more complicated than """Replicate a User to the correct Course DB. This is more complicated than
it should be because Askbot extends the auth_user table and adds its own it should be because Askbot extends the auth_user table and adds its own
fields. So we need to only push changes to the standard fields and leave fields. So we need to only push changes to the standard fields and leave
the rest alone so that Askbot can the rest alone so that Askbot changes at the Course DB level don't get
overridden.
""" """
try: try:
# If the user exists in the Course DB, update the appropriate fields and # If the user exists in the Course DB, update the appropriate fields and
# save it back out to the Course DB. # save it back out to the Course DB.
course_user = User.objects.using(course_db_name).get(portal_user.id) course_user = User.objects.using(course_db_name).get(id=portal_user.id)
fields_to_copy = ["username", "first_name", "last_name", "email", for field in USER_FIELDS_TO_COPY:
"password", "is_staff", "is_active", "is_superuser",
"last_login", "date_joined"]
for field in fields_to_copy:
setattr(course_user, field, getattr(portal_user, field)) setattr(course_user, field, getattr(portal_user, field))
mark_handled(course_user) mark_handled(course_user)
...@@ -331,6 +342,8 @@ def is_valid_course_id(course_id): ...@@ -331,6 +342,8 @@ def is_valid_course_id(course_id):
"""We check to both make sure that it's a valid course_id (and not """We check to both make sure that it's a valid course_id (and not
'default', or some other non-course DB name) and that we have a mapping 'default', or some other non-course DB name) and that we have a mapping
for what database it belongs to.""" for what database it belongs to."""
return course_id != 'default'
course_ids = set(course.id for course in modulestore().get_courses()) course_ids = set(course.id for course in modulestore().get_courses())
is_valid = (course_id in course_ids) and (course_id in settings.DATABASES) is_valid = (course_id in course_ids) and (course_id in settings.DATABASES)
if not is_valid: if not is_valid:
...@@ -370,7 +383,8 @@ def should_replicate(instance): ...@@ -370,7 +383,8 @@ def should_replicate(instance):
the instance has to not have been marked_handled.""" the instance has to not have been marked_handled."""
if marked_handled(instance): if marked_handled(instance):
# Basically, avoid an infinite loop. You should # Basically, avoid an infinite loop. You should
log.debug("{0} should not be replicated because it's been marked") log.debug("{0} should not be replicated because it's been marked"
.format(instance))
return False return False
if not is_portal(): if not is_portal():
log.debug("{0} should not be replicated because we're not a portal." log.debug("{0} should not be replicated because we're not a portal."
......
...@@ -4,13 +4,167 @@ when you run "manage.py test". ...@@ -4,13 +4,167 @@ when you run "manage.py test".
Replace this with more appropriate tests for your application. Replace this with more appropriate tests for your application.
""" """
from datetime import datetime
from django.test import TestCase from django.test import TestCase
from .models import User, UserProfile, CourseEnrollment, replicate_user, USER_FIELDS_TO_COPY
COURSE_1 = 'edX/toy/2012_Fall'
COURSE_2 = 'edx/full/6.002_Spring_2012'
class ReplicationTest(TestCase):
multi_db = True
def test_user_replication(self):
"""Test basic user replication."""
portal_user = User.objects.create_user('rusty', 'rusty@edx.org', 'fakepass')
portal_user.first_name='Rusty'
portal_user.last_name='Skids'
portal_user.is_staff=True
portal_user.is_active=True
portal_user.is_superuser=True
portal_user.last_login=datetime(2012, 1, 1)
portal_user.date_joined=datetime(2011, 1, 1)
# This is an Askbot field and will break if askbot is not included
portal_user.seen_response_count = 10
portal_user.save(using='default')
# We replicate this user to Course 1, then pull the same user and verify
# that the fields copied over properly.
replicate_user(portal_user, COURSE_1)
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
# Make sure the fields we care about got copied over for this user.
for field in USER_FIELDS_TO_COPY:
self.assertEqual(getattr(portal_user, field),
getattr(course_user, field),
"{0} not copied from {1} to {2}".format(
field, portal_user, course_user
))
# Since it's the first copy over of User data, we should have all of it
self.assertEqual(portal_user.seen_response_count,
course_user.seen_response_count)
# But if we replicate again, the user already exists in the Course DB,
# so it shouldn't update the seen_response_count (which is Askbot
# controlled)
portal_user.seen_response_count = 20
replicate_user(portal_user, COURSE_1)
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.seen_response_count, 20)
self.assertEqual(course_user.seen_response_count, 10)
# Another replication should work for an email change however, since
# it's a field we care about.
portal_user.email = "clyde@edx.org"
replicate_user(portal_user, COURSE_1)
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEqual(portal_user.email, course_user.email)
# During this entire time, the user data should never have made it over
# to COURSE_2
self.assertRaises(User.DoesNotExist,
User.objects.using(COURSE_2).get,
id=portal_user.id)
def test_enrollment_for_existing_user_info(self):
"""Test the effect of Enrolling in a class if you've already got user
data to be copied over."""
# Create our User
portal_user = User.objects.create_user('jack', 'jack@edx.org', 'fakepass')
portal_user.first_name = "Jack"
portal_user.save()
# Set up our UserProfile info
portal_user_profile = UserProfile.objects.create(
user=portal_user,
name="Jack Foo",
level_of_education=None,
gender='m',
mailing_address=None,
goals="World domination",
)
portal_user_profile.save()
# Now let's see if creating a CourseEnrollment copies all the relevant
# data.
portal_enrollment = CourseEnrollment.objects.create(user=portal_user,
course_id=COURSE_1)
portal_enrollment.save()
# Grab all the copies we expect
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEquals(portal_user, course_user)
self.assertRaises(User.DoesNotExist,
User.objects.using(COURSE_2).get,
id=portal_user.id)
course_enrollment = CourseEnrollment.objects.using(COURSE_1).get(id=portal_enrollment.id)
self.assertEquals(portal_enrollment, course_enrollment)
self.assertRaises(CourseEnrollment.DoesNotExist,
CourseEnrollment.objects.using(COURSE_2).get,
id=portal_enrollment.id)
course_user_profile = UserProfile.objects.using(COURSE_1).get(id=portal_user_profile.id)
self.assertEquals(portal_user_profile, course_user_profile)
self.assertRaises(UserProfile.DoesNotExist,
UserProfile.objects.using(COURSE_2).get,
id=portal_user_profile.id)
def test_enrollment_for_user_info_after_enrollment(self):
"""Test the effect of Enrolling in a class if you've already got user
data to be copied over."""
# Create our User
portal_user = User.objects.create_user('jack', 'jack@edx.org', 'fakepass')
portal_user.first_name = "Jack"
portal_user.save()
# Now let's see if creating a CourseEnrollment copies all the relevant
# data when things are saved.
portal_enrollment = CourseEnrollment.objects.create(user=portal_user,
course_id=COURSE_1)
portal_enrollment.save()
# Set up our UserProfile info
portal_user_profile = UserProfile.objects.create(
user=portal_user,
name="Jack Foo",
level_of_education=None,
gender='m',
mailing_address=None,
goals="World domination",
)
portal_user_profile.save()
# Grab all the copies we expect, and make sure it doesn't end up in
# places we don't expect.
course_user = User.objects.using(COURSE_1).get(id=portal_user.id)
self.assertEquals(portal_user, course_user)
self.assertRaises(User.DoesNotExist,
User.objects.using(COURSE_2).get,
id=portal_user.id)
course_enrollment = CourseEnrollment.objects.using(COURSE_1).get(id=portal_enrollment.id)
self.assertEquals(portal_enrollment, course_enrollment)
self.assertRaises(CourseEnrollment.DoesNotExist,
CourseEnrollment.objects.using(COURSE_2).get,
id=portal_enrollment.id)
course_user_profile = UserProfile.objects.using(COURSE_1).get(id=portal_user_profile.id)
self.assertEquals(portal_user_profile, course_user_profile)
self.assertRaises(UserProfile.DoesNotExist,
UserProfile.objects.using(COURSE_2).get,
id=portal_user_profile.id)
class SimpleTest(TestCase):
def test_basic_addition(self):
"""
Tests that 1 + 1 always equals 2.
"""
self.assertEqual(1 + 1, 2)
...@@ -66,6 +66,17 @@ DATABASES = { ...@@ -66,6 +66,17 @@ DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.sqlite3',
'NAME': PROJECT_ROOT / "db" / "mitx.db", 'NAME': PROJECT_ROOT / "db" / "mitx.db",
},
# The following are for testing purposes...
'edX/toy/2012_Fall': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ENV_ROOT / "db" / "course1.db",
},
'edx/full/6.002_Spring_2012': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ENV_ROOT / "db" / "course2.db",
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment