Commit 731d85f7 by Eric Fischer

CohortMembership Transaction Fixes

An issue arose recently due to ATOMIC_REQUESTS being turned on by default. It
turns out that CohortMemberships had been somewhat relying on the old default
transaction handling in order to keep CohortMemberships and the underlying
CourseUserGroup.users values in-sync.

To fix this, I've made all updates to Cohortmemberships go through an
outer_atomic(read_committed=True) block. This, is conjunction with the already
present select_for_update(), will no longer allow 2 simultaneous requests to
modify objects in memory without sharing them. Only one process will be
touching a given CohortMembership at any given time, and all changes will be
immediately comitted to the database, where the other process will see them.

I've also included some changes to get_cohort(), add_user_to_cohort(), and
remove_user_from_cohort() in order to properly make use of the new
CohortMembership system.
parent d370b0aa
...@@ -1505,7 +1505,6 @@ def cohort_students_and_upload(_xmodule_instance_args, _entry_id, course_id, tas ...@@ -1505,7 +1505,6 @@ def cohort_students_and_upload(_xmodule_instance_args, _entry_id, course_id, tas
continue continue
try: try:
with outer_atomic():
add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email) add_user_to_cohort(cohorts_status[cohort_name]['cohort'], username_or_email)
cohorts_status[cohort_name]['Students Added'] += 1 cohorts_status[cohort_name]['Students Added'] += 1
task_progress.succeeded += 1 task_progress.succeeded += 1
......
...@@ -179,22 +179,30 @@ def get_cohort(user, course_key, assign=True, use_cached=False): ...@@ -179,22 +179,30 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
if not course_cohort_settings.is_cohorted: if not course_cohort_settings.is_cohorted:
return request_cache.data.setdefault(cache_key, None) return request_cache.data.setdefault(cache_key, None)
# If course is cohorted, check if the user already has a cohort.
try: try:
cohort = CourseUserGroup.objects.get( membership = CohortMembership.objects.get(
course_id=course_key, course_id=course_key,
group_type=CourseUserGroup.COHORT, user_id=user.id,
users__id=user.id,
) )
return request_cache.data.setdefault(cache_key, cohort) return request_cache.data.setdefault(cache_key, membership.course_user_group)
except CourseUserGroup.DoesNotExist: except CohortMembership.DoesNotExist:
# Didn't find the group. If we do not want to assign, return here. # Didn't find the group. If we do not want to assign, return here.
if not assign: if not assign:
# Do not cache the cohort here, because in the next call assign # Do not cache the cohort here, because in the next call assign
# may be True, and we will have to assign the user a cohort. # may be True, and we will have to assign the user a cohort.
return None return None
# Otherwise assign the user a cohort. membership = CohortMembership.objects.create(
user=user,
course_user_group=_get_default_cohort(course_key)
)
return request_cache.data.setdefault(cache_key, membership.course_user_group)
def _get_default_cohort(course_key):
"""
Helper method to get a default cohort for assignment in get_cohort
"""
course = courses.get_course(course_key) course = courses.get_course(course_key)
cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM) cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM)
if cohorts: if cohorts:
...@@ -205,11 +213,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False): ...@@ -205,11 +213,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
course_id=course_key, course_id=course_key,
assignment_type=CourseCohort.RANDOM assignment_type=CourseCohort.RANDOM
).course_user_group ).course_user_group
return cohort
membership = CohortMembership(course_user_group=cohort, user=user)
membership.save()
return request_cache.data.setdefault(cache_key, cohort)
def migrate_cohort_settings(course): def migrate_cohort_settings(course):
...@@ -350,7 +354,7 @@ def add_user_to_cohort(cohort, username_or_email): ...@@ -350,7 +354,7 @@ def add_user_to_cohort(cohort, username_or_email):
user = get_user_by_username_or_email(username_or_email) user = get_user_by_username_or_email(username_or_email)
membership = CohortMembership(course_user_group=cohort, user=user) membership = CohortMembership(course_user_group=cohort, user=user)
membership.save() membership.save() # This will handle both cases, creation and updating, of a CohortMembership for this user.
tracker.emit( tracker.emit(
"edx.cohort.user_add_requested", "edx.cohort.user_add_requested",
......
...@@ -7,7 +7,10 @@ import logging ...@@ -7,7 +7,10 @@ import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import models, transaction, IntegrityError from django.db import models, transaction, IntegrityError
from util.db import outer_atomic
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from xmodule_django.models import CourseKeyField from xmodule_django.models import CourseKeyField
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -85,23 +88,22 @@ class CohortMembership(models.Model): ...@@ -85,23 +88,22 @@ class CohortMembership(models.Model):
raise ValidationError("Non-matching course_ids provided") raise ValidationError("Non-matching course_ids provided")
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
self.full_clean(validate_unique=False)
# Avoid infinite recursion if creating from get_or_create() call below. # Avoid infinite recursion if creating from get_or_create() call below.
# This block also allows middleware to use CohortMembership.get_or_create without worrying about outer_atomic
if 'force_insert' in kwargs and kwargs['force_insert'] is True: if 'force_insert' in kwargs and kwargs['force_insert'] is True:
with transaction.atomic():
self.course_user_group.users.add(self.user)
self.course_user_group.save()
super(CohortMembership, self).save(*args, **kwargs) super(CohortMembership, self).save(*args, **kwargs)
return return
self.full_clean(validate_unique=False) # This block will transactionally commit updates to CohortMembership and underlying course_user_groups.
# Note the use of outer_atomic, which guarantees that operations are committed to the database on block exit.
# This loop has been created to allow for optimistic locking, and retrial in case of losing a race condition. # If called from a view method, that method must be marked with @transaction.non_atomic_requests.
# The limit is 2, since select_for_update ensures atomic updates. Creation is the only possible race condition. with outer_atomic(read_committed=True):
max_retries = 2
success = False
for __ in range(max_retries):
with transaction.atomic():
try:
with transaction.atomic():
saved_membership, created = CohortMembership.objects.select_for_update().get_or_create( saved_membership, created = CohortMembership.objects.select_for_update().get_or_create(
user__id=self.user.id, user__id=self.user.id,
course_id=self.course_id, course_id=self.course_id,
...@@ -110,10 +112,12 @@ class CohortMembership(models.Model): ...@@ -110,10 +112,12 @@ class CohortMembership(models.Model):
'user': self.user 'user': self.user
} }
) )
except IntegrityError: # This can happen if simultaneous requests try to create a membership
continue
if not created: # If the membership was newly created, all the validation and course_user_group logic was settled
# with a call to self.save(force_insert=True), which gets handled above.
if created:
return
if saved_membership.course_user_group == self.course_user_group: if saved_membership.course_user_group == self.course_user_group:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format( raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=self.user.username, user_name=self.user.username,
...@@ -123,17 +127,24 @@ class CohortMembership(models.Model): ...@@ -123,17 +127,24 @@ class CohortMembership(models.Model):
self.previous_cohort_name = saved_membership.course_user_group.name self.previous_cohort_name = saved_membership.course_user_group.name
self.previous_cohort_id = saved_membership.course_user_group.id self.previous_cohort_id = saved_membership.course_user_group.id
self.previous_cohort.users.remove(self.user) self.previous_cohort.users.remove(self.user)
self.previous_cohort.save()
saved_membership.course_user_group = self.course_user_group saved_membership.course_user_group = self.course_user_group
self.course_user_group.users.add(self.user) self.course_user_group.users.add(self.user)
self.course_user_group.save()
super(CohortMembership, saved_membership).save(update_fields=['course_user_group']) super(CohortMembership, saved_membership).save(update_fields=['course_user_group'])
success = True
break
if not success: # Needs to exist outside class definition in order to use 'sender=CohortMembership'
raise IntegrityError("Unable to save membership after {} tries, aborting.".format(max_retries)) @receiver(pre_delete, sender=CohortMembership)
def remove_user_from_cohort(sender, instance, **kwargs): # pylint: disable=unused-argument
"""
Ensures that when a CohortMemebrship is deleted, the underlying CourseUserGroup
has its users list updated to reflect the change as well.
"""
instance.course_user_group.users.remove(instance.user)
instance.course_user_group.save()
class CourseUserGroupPartitionGroup(models.Model): class CourseUserGroupPartitionGroup(models.Model):
......
...@@ -10,6 +10,8 @@ from django.core.urlresolvers import reverse ...@@ -10,6 +10,8 @@ from django.core.urlresolvers import reverse
from django.http import Http404, HttpResponseBadRequest from django.http import Http404, HttpResponseBadRequest
from django.views.decorators.http import require_http_methods from django.views.decorators.http import require_http_methods
from util.json_request import expect_json, JsonResponse from util.json_request import expect_json, JsonResponse
from util.db import outer_atomic
from django.db import transaction
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.utils.translation import ugettext from django.utils.translation import ugettext
...@@ -23,7 +25,7 @@ from edxmako.shortcuts import render_to_response ...@@ -23,7 +25,7 @@ from edxmako.shortcuts import render_to_response
from . import cohorts from . import cohorts
from lms.djangoapps.django_comment_client.utils import get_discussion_category_map, get_discussion_categories_ids from lms.djangoapps.django_comment_client.utils import get_discussion_category_map, get_discussion_categories_ids
from .models import CourseUserGroup, CourseUserGroupPartitionGroup from .models import CourseUserGroup, CourseUserGroupPartitionGroup, CohortMembership
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -299,6 +301,7 @@ def users_in_cohort(request, course_key_string, cohort_id): ...@@ -299,6 +301,7 @@ def users_in_cohort(request, course_key_string, cohort_id):
'users': user_info}) 'users': user_info})
@transaction.non_atomic_requests
@ensure_csrf_cookie @ensure_csrf_cookie
@require_POST @require_POST
def add_users_to_cohort(request, course_key_string, cohort_id): def add_users_to_cohort(request, course_key_string, cohort_id):
...@@ -384,16 +387,22 @@ def remove_user_from_cohort(request, course_key_string, cohort_id): ...@@ -384,16 +387,22 @@ def remove_user_from_cohort(request, course_key_string, cohort_id):
return json_http_response({'success': False, return json_http_response({'success': False,
'msg': 'No username specified'}) 'msg': 'No username specified'})
cohort = cohorts.get_cohort_by_id(course_key, cohort_id)
try: try:
user = User.objects.get(username=username) user = User.objects.get(username=username)
cohort.users.remove(user)
return json_http_response({'success': True})
except User.DoesNotExist: except User.DoesNotExist:
log.debug('no user') log.debug('no user')
return json_http_response({'success': False, return json_http_response({'success': False,
'msg': "No user '{0}'".format(username)}) 'msg': "No user '{0}'".format(username)})
try:
membership = CohortMembership.objects.get(user=user, course_id=course_key)
membership.delete()
except CohortMembership.DoesNotExist:
pass
return json_http_response({'success': True})
def debug_cohort_mgmt(request, course_key_string): def debug_cohort_mgmt(request, course_key_string):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment