Commit 88a38965 by Eric Fischer

CohortMembership Test Fixes

There are 3 main changes in this commit:
* CohortFactory now sets up memberships properly, so consuming tests do not
need to explicitly touch CourseUserGroup.users to add() users.
* test_get_cohort_sql_queries has been updated to 3 and 9 queries when using
and not using the cache, respectively. This is needed due to each operation
needing an extra queery to get the CourseUserGroup from the CohortMembership.
* Adding remove_user_from_cohort(), the counterpart to add_user_to_cohort().
This is also to keep tests from touching the users field directly, and keep
CohortMembership data in sync.
parent 731d85f7
...@@ -34,19 +34,19 @@ class CohortedTestCase(SharedModuleStoreTestCase): ...@@ -34,19 +34,19 @@ class CohortedTestCase(SharedModuleStoreTestCase):
def setUp(self): def setUp(self):
super(CohortedTestCase, self).setUp() super(CohortedTestCase, self).setUp()
self.student_cohort = CohortFactory.create(
name="student_cohort",
course_id=self.course.id
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id
)
seed_permissions_roles(self.course.id) seed_permissions_roles(self.course.id)
self.student = UserFactory.create() self.student = UserFactory.create()
self.moderator = UserFactory.create() self.moderator = UserFactory.create()
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) CourseEnrollmentFactory(user=self.student, course_id=self.course.id)
CourseEnrollmentFactory(user=self.moderator, course_id=self.course.id) CourseEnrollmentFactory(user=self.moderator, course_id=self.course.id)
self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id)) self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id))
self.student_cohort.users.add(self.student) self.student_cohort = CohortFactory.create(
self.moderator_cohort.users.add(self.moderator) name="student_cohort",
course_id=self.course.id,
users=[self.student]
)
self.moderator_cohort = CohortFactory.create(
name="moderator_cohort",
course_id=self.course.id,
users=[self.moderator]
)
...@@ -133,8 +133,8 @@ class TestAnalyticsBasic(ModuleStoreTestCase): ...@@ -133,8 +133,8 @@ class TestAnalyticsBasic(ModuleStoreTestCase):
course = CourseFactory.create(org="test", course="course1", display_name="run1") course = CourseFactory.create(org="test", course="course1", display_name="run1")
course.cohort_config = {'cohorted': True, 'auto_cohort': True, 'auto_cohort_groups': ['cohort']} course.cohort_config = {'cohorted': True, 'auto_cohort': True, 'auto_cohort_groups': ['cohort']}
self.store.update_item(course, self.instructor.id) self.store.update_item(course, self.instructor.id)
cohort = CohortFactory.create(name='cohort', course_id=course.id)
cohorted_students = [UserFactory.create() for _ in xrange(10)] cohorted_students = [UserFactory.create() for _ in xrange(10)]
cohort = CohortFactory.create(name='cohort', course_id=course.id, users=cohorted_students)
cohorted_usernames = [student.username for student in cohorted_students] cohorted_usernames = [student.username for student in cohorted_students]
non_cohorted_student = UserFactory.create() non_cohorted_student = UserFactory.create()
for student in cohorted_students: for student in cohorted_students:
......
...@@ -17,6 +17,7 @@ from xmodule.partitions.partitions import Group, UserPartition ...@@ -17,6 +17,7 @@ from xmodule.partitions.partitions import Group, UserPartition
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, remove_user_from_cohort
from ..testutils import MobileAPITestCase, MobileAuthTestMixin, MobileCourseAccessTestMixin from ..testutils import MobileAPITestCase, MobileAuthTestMixin, MobileCourseAccessTestMixin
...@@ -744,7 +745,7 @@ class TestVideoSummaryList( ...@@ -744,7 +745,7 @@ class TestVideoSummaryList(
for cohort_index in range(len(cohorts)): for cohort_index in range(len(cohorts)):
# add user to this cohort # add user to this cohort
cohorts[cohort_index].users.add(self.user) add_user_to_cohort(cohorts[cohort_index], self.user.username)
# should only see video for this cohort # should only see video for this cohort
video_outline = self.api_response().data video_outline = self.api_response().data
...@@ -755,7 +756,7 @@ class TestVideoSummaryList( ...@@ -755,7 +756,7 @@ class TestVideoSummaryList(
) )
# remove user from this cohort # remove user from this cohort
cohorts[cohort_index].users.remove(self.user) remove_user_from_cohort(cohorts[cohort_index], self.user.username)
# un-cohorted user should see no videos # un-cohorted user should see no videos
video_outline = self.api_response().data video_outline = self.api_response().data
......
...@@ -5,7 +5,7 @@ from django.conf import settings ...@@ -5,7 +5,7 @@ from django.conf import settings
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.test.utils import override_settings from django.test.utils import override_settings
from openedx.core.djangoapps.course_groups.models import CourseUserGroup from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from django_comment_common.models import Role, Permission from django_comment_common.models import Role, Permission
from lang_pref import LANGUAGE_KEY from lang_pref import LANGUAGE_KEY
from notification_prefs import NOTIFICATION_PREF_KEY from notification_prefs import NOTIFICATION_PREF_KEY
...@@ -46,12 +46,11 @@ class NotifierUsersViewSetTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -46,12 +46,11 @@ class NotifierUsersViewSetTest(UrlResetMixin, ModuleStoreTestCase):
self.courses.append(course) self.courses.append(course)
CourseEnrollmentFactory(user=self.user, course_id=course.id) CourseEnrollmentFactory(user=self.user, course_id=course.id)
if is_user_cohorted: if is_user_cohorted:
cohort = CourseUserGroup.objects.create( cohort = CohortFactory.create(
name="Test Cohort", name="Test Cohort",
course_id=course.id, course_id=course.id,
group_type=CourseUserGroup.COHORT users=[self.user]
) )
cohort.users.add(self.user)
self.cohorts.append(cohort) self.cohorts.append(cohort)
if is_moderator: if is_moderator:
moderator_perm, _ = Permission.objects.get_or_create(name="see_all_cohorts") moderator_perm, _ = Permission.objects.get_or_create(name="see_all_cohorts")
......
...@@ -179,6 +179,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False): ...@@ -179,6 +179,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
if not course_cohort_settings.is_cohorted: if not course_cohort_settings.is_cohorted:
return request_cache.data.setdefault(cache_key, None) return request_cache.data.setdefault(cache_key, None)
# If course is cohorted, check if the user already has a cohort.
try: try:
membership = CohortMembership.objects.get( membership = CohortMembership.objects.get(
course_id=course_key, course_id=course_key,
...@@ -192,6 +193,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False): ...@@ -192,6 +193,7 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
# may be True, and we will have to assign the user a cohort. # may be True, and we will have to assign the user a cohort.
return None return None
# Otherwise assign the user a cohort.
membership = CohortMembership.objects.create( membership = CohortMembership.objects.create(
user=user, user=user,
course_user_group=_get_default_cohort(course_key) course_user_group=_get_default_cohort(course_key)
...@@ -336,6 +338,27 @@ def is_cohort_exists(course_key, name): ...@@ -336,6 +338,27 @@ def is_cohort_exists(course_key, name):
return CourseUserGroup.objects.filter(course_id=course_key, group_type=CourseUserGroup.COHORT, name=name).exists() return CourseUserGroup.objects.filter(course_id=course_key, group_type=CourseUserGroup.COHORT, name=name).exists()
def remove_user_from_cohort(cohort, username_or_email):
"""
Look up the given user, and if successful, remove them from the specified cohort.
Arguments:
cohort: CourseUserGroup
username_or_email: string. Treated as email if has '@'
Raises:
User.DoesNotExist if can't find user.
ValueError if user not already present in this cohort.
"""
user = get_user_by_username_or_email(username_or_email)
try:
membership = CohortMembership.objects.get(course_user_group=cohort, user=user)
membership.delete()
except CohortMembership.DoesNotExist:
raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort))
def add_user_to_cohort(cohort, username_or_email): def add_user_to_cohort(cohort, username_or_email):
""" """
Look up the given user, and if successful, add them to the specified cohort. Look up the given user, and if successful, add them to the specified cohort.
......
...@@ -32,6 +32,11 @@ class CohortFactory(DjangoModelFactory): ...@@ -32,6 +32,11 @@ class CohortFactory(DjangoModelFactory):
""" """
if extracted: if extracted:
self.users.add(*extracted) self.users.add(*extracted)
for user in self.users.all():
CohortMembership.objects.create(
user=user,
course_user_group=self,
)
class CourseCohortFactory(DjangoModelFactory): class CourseCohortFactory(DjangoModelFactory):
...@@ -41,18 +46,6 @@ class CourseCohortFactory(DjangoModelFactory): ...@@ -41,18 +46,6 @@ class CourseCohortFactory(DjangoModelFactory):
class Meta(object): class Meta(object):
model = CourseCohort model = CourseCohort
@post_generation
def memberships(self, create, extracted, **kwargs): # pylint: disable=unused-argument
"""
Returns the memberships linking users to this cohort.
"""
for user in self.course_user_group.users.all(): # pylint: disable=E1101
membership = CohortMembership(user=user, course_user_group=self.course_user_group)
membership.save()
course_user_group = factory.SubFactory(CohortFactory)
assignment_type = 'manual'
class CourseCohortSettingsFactory(DjangoModelFactory): class CourseCohortSettingsFactory(DjangoModelFactory):
""" """
......
...@@ -179,8 +179,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -179,8 +179,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort_id(user, course.id)) self.assertIsNone(cohorts.get_cohort_id(user, course.id))
config_course_cohorts(course, is_cohorted=True) config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort") cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
cohort.users.add(user)
self.assertEqual(cohorts.get_cohort_id(user, course.id), cohort.id) self.assertEqual(cohorts.get_cohort_id(user, course.id), cohort.id)
self.assertRaises( self.assertRaises(
...@@ -237,8 +236,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -237,8 +236,7 @@ class TestCohorts(ModuleStoreTestCase):
self.assertIsNone(cohorts.get_cohort(user, course.id), "No cohort created yet") self.assertIsNone(cohorts.get_cohort(user, course.id), "No cohort created yet")
cohort = CohortFactory(course_id=course.id, name="TestCohort") cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user])
cohort.users.add(user)
self.assertIsNone( self.assertIsNone(
cohorts.get_cohort(user, course.id), cohorts.get_cohort(user, course.id),
...@@ -261,8 +259,8 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -261,8 +259,8 @@ class TestCohorts(ModuleStoreTestCase):
) )
@ddt.data( @ddt.data(
(True, 2), (True, 3),
(False, 6), (False, 9),
) )
@ddt.unpack @ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries): def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
...@@ -271,10 +269,8 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -271,10 +269,8 @@ class TestCohorts(ModuleStoreTestCase):
""" """
course = modulestore().get_course(self.toy_course_key) course = modulestore().get_course(self.toy_course_key)
config_course_cohorts(course, is_cohorted=True) config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
user = UserFactory(username="test", email="a@b.com") user = UserFactory(username="test", email="a@b.com")
cohort.users.add(user) CohortFactory.create(course_id=course.id, name="TestCohort", users=[user])
with self.assertNumQueries(num_sql_queries): with self.assertNumQueries(num_sql_queries):
for __ in range(3): for __ in range(3):
...@@ -314,10 +310,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -314,10 +310,7 @@ class TestCohorts(ModuleStoreTestCase):
user1 = UserFactory(username="test", email="a@b.com") user1 = UserFactory(username="test", email="a@b.com")
user2 = UserFactory(username="test2", email="a2@b.com") user2 = UserFactory(username="test2", email="a2@b.com")
cohort = CohortFactory(course_id=course.id, name="TestCohort") cohort = CohortFactory(course_id=course.id, name="TestCohort", users=[user1])
# user1 manually added to a cohort
cohort.users.add(user1)
# Add an auto_cohort_group to the course... # Add an auto_cohort_group to the course...
config_course_cohorts( config_course_cohorts(
......
...@@ -21,7 +21,7 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti ...@@ -21,7 +21,7 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti
from ..partition_scheme import CohortPartitionScheme, get_cohorted_user_partition from ..partition_scheme import CohortPartitionScheme, get_cohorted_user_partition
from ..models import CourseUserGroupPartitionGroup from ..models import CourseUserGroupPartitionGroup
from ..views import link_cohort_to_partition_group, unlink_cohort_partition_group from ..views import link_cohort_to_partition_group, unlink_cohort_partition_group
from ..cohorts import add_user_to_cohort, get_course_cohorts from ..cohorts import add_user_to_cohort, remove_user_from_cohort, get_course_cohorts
from .helpers import CohortFactory, config_course_cohorts from .helpers import CohortFactory, config_course_cohorts
...@@ -100,7 +100,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase): ...@@ -100,7 +100,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
self.assert_student_in_group(self.groups[1]) self.assert_student_in_group(self.groups[1])
# move the student out of the cohort # move the student out of the cohort
second_cohort.users.remove(self.student) remove_user_from_cohort(second_cohort, self.student.username)
self.assert_student_in_group(None) self.assert_student_in_group(None)
def test_cohort_partition_group_assignment(self): def test_cohort_partition_group_assignment(self):
......
...@@ -10,7 +10,6 @@ from django.core.urlresolvers import reverse ...@@ -10,7 +10,6 @@ from django.core.urlresolvers import reverse
from django.http import Http404, HttpResponseBadRequest from django.http import Http404, HttpResponseBadRequest
from django.views.decorators.http import require_http_methods from django.views.decorators.http import require_http_methods
from util.json_request import expect_json, JsonResponse from util.json_request import expect_json, JsonResponse
from util.db import outer_atomic
from django.db import transaction from django.db import transaction
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.utils.translation import ugettext from django.utils.translation import ugettext
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment