Commit 44eaca08 by Eric Fischer

Merge pull request #10772 from edx/efischer/post_rc_fix

CohortMembership Transaction Fixes
parents 3b5d9fe9 88a38965
......@@ -34,19 +34,19 @@ class CohortedTestCase(SharedModuleStoreTestCase):
def setUp(self):
super(CohortedTestCase, self).setUp()
self.student_cohort = CohortFactory.create(
name="student_cohort",
course_id=self.course.id
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id
)
seed_permissions_roles(self.course.id)
self.student = UserFactory.create()
self.moderator = UserFactory.create()
CourseEnrollmentFactory(user=self.student, course_id=self.course.id)
CourseEnrollmentFactory(user=self.moderator, course_id=self.course.id)
self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id))
self.student_cohort.users.add(self.student)
self.moderator_cohort.users.add(self.moderator)
self.student_cohort = CohortFactory.create(
name="student_cohort",
course_id=self.course.id,
users=[self.student]
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id,
users=[self.moderator]
)
......@@ -133,8 +133,8 @@ class TestAnalyticsBasic(ModuleStoreTestCase):
course = CourseFactory.create(org="test", course="course1", display_name="run1")
course.cohort_config = {'cohorted': True, 'auto_cohort': True, 'auto_cohort_groups': ['cohort']}
self.store.update_item(course, self.instructor.id)
cohort = CohortFactory.create(name='cohort', course_id=course.id)
cohorted_students = [UserFactory.create() for _ in xrange(10)]
cohort = CohortFactory.create(name='cohort', course_id=course.id, users=cohorted_students)
cohorted_usernames = [student.username for student in cohorted_students]
non_cohorted_student = UserFactory.create()
for student in cohorted_students:
......
......@@ -1505,8 +1505,7 @@ def cohort_students_and_upload(_xmodule_instance_args, _entry_id, course_id, tas
continue
try:
with outer_atomic():
add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email)
add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email)
cohorts_status[cohort_name]['Students Added'] += 1
task_progress.succeeded += 1
except User.DoesNotExist:
......
......@@ -17,6 +17,7 @@ from xmodule.partitions.partitions import Group, UserPartition
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, remove_user_from_cohort
from ..testutils import MobileAPITestCase, MobileAuthTestMixin, MobileCourseAccessTestMixin
......@@ -744,7 +745,7 @@ class TestVideoSummaryList(
for cohort_index in range(len(cohorts)):
# add user to this cohort
cohorts[cohort_index].users.add(self.user)
add_user_to_cohort(cohorts[cohort_index], self.user.username)
# should only see video for this cohort
video_outline = self.api_response().data
......@@ -755,7 +756,7 @@ class TestVideoSummaryList(
)
# remove user from this cohort
cohorts[cohort_index].users.remove(self.user)
remove_user_from_cohort(cohorts[cohort_index], self.user.username)
# un-cohorted user should see no videos
video_outline = self.api_response().data
......
......@@ -5,7 +5,7 @@ from django.conf import settings
from django.test.client import RequestFactory
from django.test.utils import override_settings
from openedx.core.djangoapps.course_groups.models import CourseUserGroup
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from django_comment_common.models import Role, Permission
from lang_pref import LANGUAGE_KEY
from notification_prefs import NOTIFICATION_PREF_KEY
......@@ -46,12 +46,11 @@ class NotifierUsersViewSetTest(UrlResetMixin, ModuleStoreTestCase):
self.courses.append(course)
CourseEnrollmentFactory(user=self.user, course_id=course.id)
if is_user_cohorted:
cohort = CourseUserGroup.objects.create(
cohort = CohortFactory.create(
name="Test Cohort",
course_id=course.id,
group_type=CourseUserGroup.COHORT
users=[self.user]
)
cohort.users.add(self.user)
self.cohorts.append(cohort)
if is_moderator:
moderator_perm, _ = Permission.objects.get_or_create(name="see_all_cohorts")
......
......@@ -181,13 +181,12 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
# If course is cohorted, check if the user already has a cohort.
try:
cohort = CourseUserGroup.objects.get(
membership = CohortMembership.objects.get(
course_id=course_key,
group_type=CourseUserGroup.COHORT,
users__id=user.id,
user_id=user.id,
)
return request_cache.data.setdefault(cache_key, cohort)
except CourseUserGroup.DoesNotExist:
return request_cache.data.setdefault(cache_key, membership.course_user_group)
except CohortMembership.DoesNotExist:
# Didn't find the group. If we do not want to assign, return here.
if not assign:
# Do not cache the cohort here, because in the next call assign
......@@ -195,6 +194,17 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
return None
# Otherwise assign the user a cohort.
membership = CohortMembership.objects.create(
user=user,
course_user_group=_get_default_cohort(course_key)
)
return request_cache.data.setdefault(cache_key, membership.course_user_group)
def _get_default_cohort(course_key):
"""
Helper method to get a default cohort for assignment in get_cohort
"""
course = courses.get_course(course_key)
cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM)
if cohorts:
......@@ -205,11 +215,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
course_id=course_key,
assignment_type=CourseCohort.RANDOM
).course_user_group
membership = CohortMembership(course_user_group=cohort, user=user)
membership.save()
return request_cache.data.setdefault(cache_key, cohort)
return cohort
def migrate_cohort_settings(course):
......@@ -332,6 +338,27 @@ def is_cohort_exists(course_key, name):
return CourseUserGroup.objects.filter(course_id=course_key, group_type=CourseUserGroup.COHORT, name=name).exists()
def remove_user_from_cohort(cohort, username_or_email):
"""
Look up the given user, and if successful, remove them from the specified cohort.
Arguments:
cohort: CourseUserGroup
username_or_email: string. Treated as email if has '@'
Raises:
User.DoesNotExist if can't find user.
ValueError if user not already present in this cohort.
"""
user = get_user_by_username_or_email(username_or_email)
try:
membership = CohortMembership.objects.get(course_user_group=cohort, user=user)
membership.delete()
except CohortMembership.DoesNotExist:
raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort))
def add_user_to_cohort(cohort, username_or_email):
"""
Look up the given user, and if successful, add them to the specified cohort.
......@@ -350,7 +377,7 @@ def add_user_to_cohort(cohort, username_or_email):
user = get_user_by_username_or_email(username_or_email)
membership = CohortMembership(course_user_group=cohort, user=user)
membership.save()
membership.save() # This will handle both cases, creation and updating, of a CohortMembership for this user.
tracker.emit(
"edx.cohort.user_add_requested",
......
......@@ -7,7 +7,10 @@ import logging
from django.contrib.auth.models import User
from django.db import models, transaction, IntegrityError
from util.db import outer_atomic
from django.core.exceptions import ValidationError
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from xmodule_django.models import CourseKeyField
log = logging.getLogger(__name__)
......@@ -85,55 +88,63 @@ class CohortMembership(models.Model):
raise ValidationError("Non-matching course_ids provided")
def save(self, *args, **kwargs):
# Avoid infinite recursion if creating from get_or_create() call below.
if 'force_insert' in kwargs and kwargs['force_insert'] is True:
super(CohortMembership, self).save(*args, **kwargs)
return
self.full_clean(validate_unique=False)
# This loop has been created to allow for optimistic locking, and retrial in case of losing a race condition.
# The limit is 2, since select_for_update ensures atomic updates. Creation is the only possible race condition.
max_retries = 2
success = False
for __ in range(max_retries):
# Avoid infinite recursion if creating from get_or_create() call below.
# This block also allows middleware to use CohortMembership.get_or_create without worrying about outer_atomic
if 'force_insert' in kwargs and kwargs['force_insert'] is True:
with transaction.atomic():
try:
with transaction.atomic():
saved_membership, created = CohortMembership.objects.select_for_update().get_or_create(
user__id=self.user.id,
course_id=self.course_id,
defaults={
'course_user_group': self.course_user_group,
'user': self.user
}
)
except IntegrityError: # This can happen if simultaneous requests try to create a membership
continue
if not created:
if saved_membership.course_user_group == self.course_user_group:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=self.user.username,
cohort_name=self.course_user_group.name
))
self.previous_cohort = saved_membership.course_user_group
self.previous_cohort_name = saved_membership.course_user_group.name
self.previous_cohort_id = saved_membership.course_user_group.id
self.previous_cohort.users.remove(self.user)
saved_membership.course_user_group = self.course_user_group
self.course_user_group.users.add(self.user)
self.course_user_group.save()
super(CohortMembership, self).save(*args, **kwargs)
return
super(CohortMembership, saved_membership).save(update_fields=['course_user_group'])
success = True
break
if not success:
raise IntegrityError("Unable to save membership after {} tries, aborting.".format(max_retries))
# This block will transactionally commit updates to CohortMembership and underlying course_user_groups.
# Note the use of outer_atomic, which guarantees that operations are committed to the database on block exit.
# If called from a view method, that method must be marked with @transaction.non_atomic_requests.
with outer_atomic(read_committed=True):
saved_membership, created = CohortMembership.objects.select_for_update().get_or_create(
user__id=self.user.id,
course_id=self.course_id,
defaults={
'course_user_group': self.course_user_group,
'user': self.user
}
)
# If the membership was newly created, all the validation and course_user_group logic was settled
# with a call to self.save(force_insert=True), which gets handled above.
if created:
return
if saved_membership.course_user_group == self.course_user_group:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=self.user.username,
cohort_name=self.course_user_group.name
))
self.previous_cohort = saved_membership.course_user_group
self.previous_cohort_name = saved_membership.course_user_group.name
self.previous_cohort_id = saved_membership.course_user_group.id
self.previous_cohort.users.remove(self.user)
self.previous_cohort.save()
saved_membership.course_user_group = self.course_user_group
self.course_user_group.users.add(self.user)
self.course_user_group.save()
super(CohortMembership, saved_membership).save(update_fields=['course_user_group'])
# Needs to exist outside class definition in order to use 'sender=CohortMembership'
@receiver(pre_delete, sender=CohortMembership)
def remove_user_from_cohort(sender, instance, **kwargs): # pylint: disable=unused-argument
"""
Ensures that when a CohortMemebrship is deleted, the underlying CourseUserGroup
has its users list updated to reflect the change as well.
"""
instance.course_user_group.users.remove(instance.user)
instance.course_user_group.save()
class CourseUserGroupPartitionGroup(models.Model):
......
......@@ -32,6 +32,11 @@ class CohortFactory(DjangoModelFactory):
"""
if extracted:
self.users.add(*extracted)
for user in self.users.all():
CohortMembership.objects.create(
user=user,
course_user_group=self,
)
class CourseCohortFactory(DjangoModelFactory):
......@@ -41,18 +46,6 @@ class CourseCohortFactory(DjangoModelFactory):
class Meta(object):
model = CourseCohort
@post_generation
def memberships(self, create, extracted, **kwargs): # pylint: disable=unused-argument
"""
Returns the memberships linking users to this cohort.
"""
for user in self.course_user_group.users.all(): # pylint: disable=E1101
membership = CohortMembership(user=user, course_user_group=self.course_user_group)
membership.save()
course_user_group = factory.SubFactory(CohortFactory)
assignment_type = 'manual'
class CourseCohortSettingsFactory(DjangoModelFactory):
"""
......
......@@ -179,8 +179,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort_id(user, course.id))
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
cohort.users.add(user)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
self.assertEqual(cohorts.get_cohort_id(user, course.id), cohort.id)
self.assertRaises(
......@@ -237,8 +236,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort(user, course.id), "No cohort created yet")
cohort = CohortFactory(course_id=course.id, name="TestCohort")
cohort.users.add(user)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
self.assertIsNone(
cohorts.get_cohort(user, course.id),
......@@ -261,8 +259,8 @@ class TestCohorts(ModuleStoreTestCase):
)
@ddt.data(
(True, 2),
(False, 6),
(True, 3),
(False, 9),
)
@ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
......@@ -271,10 +269,8 @@ class TestCohorts(ModuleStoreTestCase):
"""
course = modulestore().get_course(self.toy_course_key)
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
user = UserFactory(username="test", email="a@b.com")
cohort.users.add(user)
CohortFactory.create(course_id=course.id, name="TestCohort", users=[user])
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
......@@ -314,10 +310,7 @@ class TestCohorts(ModuleStoreTestCase):
user1 = UserFactory(username="test", email="a@b.com")
user2 = UserFactory(username="test2", email="a2@b.com")
cohort = CohortFactory(course_id=course.id, name="TestCohort")
# user1 manually added to a cohort
cohort.users.add(user1)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user1])
# Add an auto_cohort_group to the course...
config_course_cohorts(
......
......@@ -21,7 +21,7 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti
from ..partition_scheme import CohortPartitionScheme, get_cohorted_user_partition
from ..models import CourseUserGroupPartitionGroup
from ..views import link_cohort_to_partition_group, unlink_cohort_partition_group
from ..cohorts import add_user_to_cohort, get_course_cohorts
from ..cohorts import add_user_to_cohort, remove_user_from_cohort, get_course_cohorts
from .helpers import CohortFactory, config_course_cohorts
......@@ -100,7 +100,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
self.assert_student_in_group(self.groups[1])
# move the student out of the cohort
second_cohort.users.remove(self.student)
remove_user_from_cohort(second_cohort, self.student.username)
self.assert_student_in_group(None)
def test_cohort_partition_group_assignment(self):
......
......@@ -10,6 +10,7 @@ from django.core.urlresolvers import reverse
from django.http import Http404, HttpResponseBadRequest
from django.views.decorators.http import require_http_methods
from util.json_request import expect_json, JsonResponse
from django.db import transaction
from django.contrib.auth.decorators import login_required
from django.utils.translation import ugettext
......@@ -23,7 +24,7 @@ from edxmako.shortcuts import render_to_response
from . import cohorts
from lms.djangoapps.django_comment_client.utils import get_discussion_category_map, get_discussion_categories_ids
from .models import CourseUserGroup, CourseUserGroupPartitionGroup
from .models import CourseUserGroup, CourseUserGroupPartitionGroup, CohortMembership
log = logging.getLogger(__name__)
......@@ -299,6 +300,7 @@ def users_in_cohort(request, course_key_string, cohort_id):
'users': user_info})
@transaction.non_atomic_requests
@ensure_csrf_cookie
@require_POST
def add_users_to_cohort(request, course_key_string, cohort_id):
......@@ -384,16 +386,22 @@ def remove_user_from_cohort(request, course_key_string, cohort_id):
return json_http_response({'success': False,
'msg': 'No username specified'})
cohort = cohorts.get_cohort_by_id(course_key, cohort_id)
try:
user = User.objects.get(username=username)
cohort.users.remove(user)
return json_http_response({'success': True})
except User.DoesNotExist:
log.debug('no user')
return json_http_response({'success': False,
'msg': "No user '{0}'".format(username)})
try:
membership = CohortMembership.objects.get(user=user, course_id=course_key)
membership.delete()
except CohortMembership.DoesNotExist:
pass
return json_http_response({'success': True})
def debug_cohort_mgmt(request, course_key_string):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment