Commit 83163e58 by Eric Fischer

CohortMemberships Code Changes

The code changes needed to get CohortMembership functioning properly.

The key of this change is twofold: first, CohortMemberships are unique
per-user, per-course. This is enforced at the database level. Secondly,
the updates are done using a select_for_update, which ensures atomicity.
parent 56d49e73
...@@ -6,7 +6,6 @@ forums, and to the cohort admin views. ...@@ -6,7 +6,6 @@ forums, and to the cohort admin views.
import logging import logging
import random import random
from django.db import transaction
from django.db.models.signals import post_save, m2m_changed from django.db.models.signals import post_save, m2m_changed
from django.dispatch import receiver from django.dispatch import receiver
from django.http import Http404 from django.http import Http404
...@@ -17,7 +16,13 @@ from eventtracking import tracker ...@@ -17,7 +16,13 @@ from eventtracking import tracker
from request_cache.middleware import RequestCache from request_cache.middleware import RequestCache
from student.models import get_user_by_username_or_email from student.models import get_user_by_username_or_email
from .models import CourseUserGroup, CourseCohort, CourseCohortsSettings, CourseUserGroupPartitionGroup from .models import (
CourseUserGroup,
CourseCohort,
CourseCohortsSettings,
CourseUserGroupPartitionGroup,
CohortMembership
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -140,7 +145,6 @@ def get_cohorted_commentables(course_key): ...@@ -140,7 +145,6 @@ def get_cohorted_commentables(course_key):
return ans return ans
@transaction.commit_on_success
def get_cohort(user, course_key, assign=True, use_cached=False): def get_cohort(user, course_key, assign=True, use_cached=False):
"""Returns the user's cohort for the specified course. """Returns the user's cohort for the specified course.
...@@ -202,7 +206,8 @@ def get_cohort(user, course_key, assign=True, use_cached=False): ...@@ -202,7 +206,8 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
assignment_type=CourseCohort.RANDOM assignment_type=CourseCohort.RANDOM
).course_user_group ).course_user_group
user.course_groups.add(cohort) membership = CohortMembership(course_user_group=cohort, user=user)
membership.save()
return request_cache.data.setdefault(cache_key, cohort) return request_cache.data.setdefault(cache_key, cohort)
...@@ -343,25 +348,9 @@ def add_user_to_cohort(cohort, username_or_email): ...@@ -343,25 +348,9 @@ def add_user_to_cohort(cohort, username_or_email):
ValueError if user already present in this cohort. ValueError if user already present in this cohort.
""" """
user = get_user_by_username_or_email(username_or_email) user = get_user_by_username_or_email(username_or_email)
previous_cohort_name = None
previous_cohort_id = None
course_cohorts = CourseUserGroup.objects.filter( membership = CohortMembership(course_user_group=cohort, user=user)
course_id=cohort.course_id, membership.save()
users__id=user.id,
group_type=CourseUserGroup.COHORT
)
if course_cohorts.exists():
if course_cohorts[0] == cohort:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=user.username,
cohort_name=cohort.name
))
else:
previous_cohort = course_cohorts[0]
previous_cohort.users.remove(user)
previous_cohort_name = previous_cohort.name
previous_cohort_id = previous_cohort.id
tracker.emit( tracker.emit(
"edx.cohort.user_add_requested", "edx.cohort.user_add_requested",
...@@ -369,12 +358,11 @@ def add_user_to_cohort(cohort, username_or_email): ...@@ -369,12 +358,11 @@ def add_user_to_cohort(cohort, username_or_email):
"user_id": user.id, "user_id": user.id,
"cohort_id": cohort.id, "cohort_id": cohort.id,
"cohort_name": cohort.name, "cohort_name": cohort.name,
"previous_cohort_id": previous_cohort_id, "previous_cohort_id": membership.previous_cohort_id,
"previous_cohort_name": previous_cohort_name, "previous_cohort_name": membership.previous_cohort_name,
} }
) )
cohort.users.add(user) return (user, membership.previous_cohort_name)
return (user, previous_cohort_name)
def get_group_info_for_cohort(cohort, use_cached=False): def get_group_info_for_cohort(cohort, use_cached=False):
......
...@@ -6,7 +6,8 @@ import json ...@@ -6,7 +6,8 @@ import json
import logging import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import models from django.db import models, transaction, IntegrityError
from django.core.exceptions import ValidationError
from xmodule_django.models import CourseKeyField from xmodule_django.models import CourseKeyField
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -58,6 +59,97 @@ class CourseUserGroup(models.Model): ...@@ -58,6 +59,97 @@ class CourseUserGroup(models.Model):
) )
class CohortMembership(models.Model):
"""Used internally to enforce our particular definition of uniqueness"""
course_user_group = models.ForeignKey(CourseUserGroup)
user = models.ForeignKey(User)
course_id = CourseKeyField(max_length=255)
previous_cohort = None
previous_cohort_name = None
previous_cohort_id = None
class Meta(object):
unique_together = (('user', 'course_id'), )
# The sole purpose of overriding this method is to get the django 1.6 behavior of allowing 'validate_unique'
# For django 1.8 upgrade, just remove this method and allow the base method to be called instead.
# Reference: https://docs.djangoproject.com/en/1.6/ref/models/instances/, under "Validating Objects"
def full_clean(self, **kwargs):
self.clean_fields()
self.clean()
if 'validate_unique' not in kwargs or kwargs['validate_unique'] is True:
self.validate_unique()
def clean_fields(self, *args, **kwargs):
if self.course_id is None:
self.course_id = self.course_user_group.course_id
super(CohortMembership, self).clean_fields(*args, **kwargs)
def clean(self):
if self.course_user_group.group_type != CourseUserGroup.COHORT: # pylint: disable=E1101
raise ValidationError("CohortMembership cannot be used with CourseGroup types other than COHORT")
if self.course_user_group.course_id != self.course_id:
raise ValidationError("Non-matching course_ids provided")
def save(self, *args, **kwargs):
# Avoid infinite recursion if creating from get_or_create() call below.
if 'force_insert' in kwargs and kwargs['force_insert'] is True:
super(CohortMembership, self).save(*args, **kwargs)
return
self.full_clean(validate_unique=False)
# This loop has been created to allow for optimistic locking, and retrial in case of losing a race condition.
# The limit is 2, since select_for_update ensures atomic updates. Creation is the only possible race condition.
max_retries = 2
success = False
for __ in range(max_retries):
# The following 2 "transaction" lines force a fresh read, they can be removed once we're on django 1.8
# http://stackoverflow.com/questions/3346124/how-do-i-force-django-to-ignore-any-caches-and-reload-data
with transaction.commit_manually():
transaction.commit()
with transaction.commit_on_success():
try:
saved_membership, created = CohortMembership.objects.select_for_update().get_or_create(
user__id=self.user.id, # pylint: disable=E1101
course_id=self.course_id,
defaults={
'course_user_group': self.course_user_group,
'user': self.user
}
)
except IntegrityError: # This can happen if simultaneous requests try to create a membership
transaction.rollback()
continue
if not created:
if saved_membership.course_user_group == self.course_user_group:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=self.user.username, # pylint: disable=E1101
cohort_name=self.course_user_group.name
))
self.previous_cohort = saved_membership.course_user_group
self.previous_cohort_name = saved_membership.course_user_group.name
self.previous_cohort_id = saved_membership.course_user_group.id
self.previous_cohort.users.remove(self.user)
saved_membership.course_user_group = self.course_user_group
self.course_user_group.users.add(self.user) # pylint: disable=E1101
#note: in django 1.8, we can call save with updated_fields=['course_user_group']
super(CohortMembership, saved_membership).save()
success = True
break
if not success:
raise IntegrityError("Unable to save membership after {} tries, aborting.".format(max_retries))
class CourseUserGroupPartitionGroup(models.Model): class CourseUserGroupPartitionGroup(models.Model):
""" """
Create User Partition Info. Create User Partition Info.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment