Commit 88a38965 by Eric Fischer

CohortMembership Test Fixes

There are 3 main changes in this commit:
* CohortFactory now sets up memberships properly, so consuming tests do not
need to explicitly touch CourseUserGroup.users to add() users.
* test_get_cohort_sql_queries has been updated to 3 and 9 queries when using
and not using the cache, respectively. This is needed due to each operation
needing an extra queery to get the CourseUserGroup from the CohortMembership.
* Adding remove_user_from_cohort(), the counterpart to add_user_to_cohort().
This is also to keep tests from touching the users field directly, and keep
CohortMembership data in sync.
parent 731d85f7
......@@ -34,19 +34,19 @@ class CohortedTestCase(SharedModuleStoreTestCase):
def setUp(self):
super(CohortedTestCase, self).setUp()
self.student_cohort = CohortFactory.create(
name="student_cohort",
course_id=self.course.id
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id
)
seed_permissions_roles(self.course.id)
self.student = UserFactory.create()
self.moderator = UserFactory.create()
CourseEnrollmentFactory(user=self.student, course_id=self.course.id)
CourseEnrollmentFactory(user=self.moderator, course_id=self.course.id)
self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id))
self.student_cohort.users.add(self.student)
self.moderator_cohort.users.add(self.moderator)
self.student_cohort = CohortFactory.create(
name="student_cohort",
course_id=self.course.id,
users=[self.student]
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id,
users=[self.moderator]
)
......@@ -133,8 +133,8 @@ class TestAnalyticsBasic(ModuleStoreTestCase):
course = CourseFactory.create(org="test", course="course1", display_name="run1")
course.cohort_config = {'cohorted': True, 'auto_cohort': True, 'auto_cohort_groups': ['cohort']}
self.store.update_item(course, self.instructor.id)
cohort = CohortFactory.create(name='cohort', course_id=course.id)
cohorted_students = [UserFactory.create() for _ in xrange(10)]
cohort = CohortFactory.create(name='cohort', course_id=course.id, users=cohorted_students)
cohorted_usernames = [student.username for student in cohorted_students]
non_cohorted_student = UserFactory.create()
for student in cohorted_students:
......
......@@ -17,6 +17,7 @@ from xmodule.partitions.partitions import Group, UserPartition
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, remove_user_from_cohort
from ..testutils import MobileAPITestCase, MobileAuthTestMixin, MobileCourseAccessTestMixin
......@@ -744,7 +745,7 @@ class TestVideoSummaryList(
for cohort_index in range(len(cohorts)):
# add user to this cohort
cohorts[cohort_index].users.add(self.user)
add_user_to_cohort(cohorts[cohort_index], self.user.username)
# should only see video for this cohort
video_outline = self.api_response().data
......@@ -755,7 +756,7 @@ class TestVideoSummaryList(
)
# remove user from this cohort
cohorts[cohort_index].users.remove(self.user)
remove_user_from_cohort(cohorts[cohort_index], self.user.username)
# un-cohorted user should see no videos
video_outline = self.api_response().data
......
......@@ -5,7 +5,7 @@ from django.conf import settings
from django.test.client import RequestFactory
from django.test.utils import override_settings
from openedx.core.djangoapps.course_groups.models import CourseUserGroup
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from django_comment_common.models import Role, Permission
from lang_pref import LANGUAGE_KEY
from notification_prefs import NOTIFICATION_PREF_KEY
......@@ -46,12 +46,11 @@ class NotifierUsersViewSetTest(UrlResetMixin, ModuleStoreTestCase):
self.courses.append(course)
CourseEnrollmentFactory(user=self.user, course_id=course.id)
if is_user_cohorted:
cohort = CourseUserGroup.objects.create(
cohort = CohortFactory.create(
name="Test Cohort",
course_id=course.id,
group_type=CourseUserGroup.COHORT
users=[self.user]
)
cohort.users.add(self.user)
self.cohorts.append(cohort)
if is_moderator:
moderator_perm, _ = Permission.objects.get_or_create(name="see_all_cohorts")
......
......@@ -179,6 +179,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
if not course_cohort_settings.is_cohorted:
return request_cache.data.setdefault(cache_key, None)
# If course is cohorted, check if the user already has a cohort.
try:
membership = CohortMembership.objects.get(
course_id=course_key,
......@@ -192,6 +193,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
# may be True, and we will have to assign the user a cohort.
return None
# Otherwise assign the user a cohort.
membership = CohortMembership.objects.create(
user=user,
course_user_group=_get_default_cohort(course_key)
......@@ -336,6 +338,27 @@ def is_cohort_exists(course_key, name):
return CourseUserGroup.objects.filter(course_id=course_key, group_type=CourseUserGroup.COHORT, name=name).exists()
def remove_user_from_cohort(cohort, username_or_email):
"""
Look up the given user, and if successful, remove them from the specified cohort.
Arguments:
cohort: CourseUserGroup
username_or_email: string. Treated as email if has '@'
Raises:
User.DoesNotExist if can't find user.
ValueError if user not already present in this cohort.
"""
user = get_user_by_username_or_email(username_or_email)
try:
membership = CohortMembership.objects.get(course_user_group=cohort, user=user)
membership.delete()
except CohortMembership.DoesNotExist:
raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort))
def add_user_to_cohort(cohort, username_or_email):
"""
Look up the given user, and if successful, add them to the specified cohort.
......
......@@ -32,6 +32,11 @@ class CohortFactory(DjangoModelFactory):
"""
if extracted:
self.users.add(*extracted)
for user in self.users.all():
CohortMembership.objects.create(
user=user,
course_user_group=self,
)
class CourseCohortFactory(DjangoModelFactory):
......@@ -41,18 +46,6 @@ class CourseCohortFactory(DjangoModelFactory):
class Meta(object):
model = CourseCohort
@post_generation
def memberships(self, create, extracted, **kwargs): # pylint: disable=unused-argument
"""
Returns the memberships linking users to this cohort.
"""
for user in self.course_user_group.users.all(): # pylint: disable=E1101
membership = CohortMembership(user=user, course_user_group=self.course_user_group)
membership.save()
course_user_group = factory.SubFactory(CohortFactory)
assignment_type = 'manual'
class CourseCohortSettingsFactory(DjangoModelFactory):
"""
......
......@@ -179,8 +179,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort_id(user, course.id))
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
cohort.users.add(user)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
self.assertEqual(cohorts.get_cohort_id(user, course.id), cohort.id)
self.assertRaises(
......@@ -237,8 +236,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort(user, course.id), "No cohort created yet")
cohort = CohortFactory(course_id=course.id, name="TestCohort")
cohort.users.add(user)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
self.assertIsNone(
cohorts.get_cohort(user, course.id),
......@@ -261,8 +259,8 @@ class TestCohorts(ModuleStoreTestCase):
)
@ddt.data(
(True, 2),
(False, 6),
(True, 3),
(False, 9),
)
@ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
......@@ -271,10 +269,8 @@ class TestCohorts(ModuleStoreTestCase):
"""
course = modulestore().get_course(self.toy_course_key)
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
user = UserFactory(username="test", email="a@b.com")
cohort.users.add(user)
CohortFactory.create(course_id=course.id, name="TestCohort", users=[user])
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
......@@ -314,10 +310,7 @@ class TestCohorts(ModuleStoreTestCase):
user1 = UserFactory(username="test", email="a@b.com")
user2 = UserFactory(username="test2", email="a2@b.com")
cohort = CohortFactory(course_id=course.id, name="TestCohort")
# user1 manually added to a cohort
cohort.users.add(user1)
cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user1])
# Add an auto_cohort_group to the course...
config_course_cohorts(
......
......@@ -21,7 +21,7 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti
from ..partition_scheme import CohortPartitionScheme, get_cohorted_user_partition
from ..models import CourseUserGroupPartitionGroup
from ..views import link_cohort_to_partition_group, unlink_cohort_partition_group
from ..cohorts import add_user_to_cohort, get_course_cohorts
from ..cohorts import add_user_to_cohort, remove_user_from_cohort, get_course_cohorts
from .helpers import CohortFactory, config_course_cohorts
......@@ -100,7 +100,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
self.assert_student_in_group(self.groups[1])
# move the student out of the cohort
second_cohort.users.remove(self.student)
remove_user_from_cohort(second_cohort, self.student.username)
self.assert_student_in_group(None)
def test_cohort_partition_group_assignment(self):
......
......@@ -10,7 +10,6 @@ from django.core.urlresolvers import reverse
from django.http import Http404, HttpResponseBadRequest
from django.views.decorators.http import require_http_methods
from util.json_request import expect_json, JsonResponse
from util.db import outer_atomic
from django.db import transaction
from django.contrib.auth.decorators import login_required
from django.utils.translation import ugettext
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment