Commit 3488083e by Victor Shnayder Committed by Victor Shnayder

Add a get_course_cohorts function and test

parent 90a57034
...@@ -40,9 +40,9 @@ def get_cohort(user, course_id): ...@@ -40,9 +40,9 @@ def get_cohort(user, course_id):
Returns: Returns:
A CourseUserGroup object if the User has a cohort, or None. A CourseUserGroup object if the User has a cohort, or None.
""" """
group_type = CourseUserGroup.COHORT
try: try:
group = CourseUserGroup.objects.get(course_id=course_id, group_type=group_type, group = CourseUserGroup.objects.get(course_id=course_id,
group_type=CourseUserGroup.COHORT,
users__id=user.id) users__id=user.id)
except CourseUserGroup.DoesNotExist: except CourseUserGroup.DoesNotExist:
group = None group = None
...@@ -52,3 +52,16 @@ def get_cohort(user, course_id): ...@@ -52,3 +52,16 @@ def get_cohort(user, course_id):
# TODO: add auto-cohorting logic here # TODO: add auto-cohorting logic here
return None return None
def get_course_cohorts(course_id):
"""
Get a list of all the cohorts in the given course.
Arguments:
course_id: string in the format 'org/course/run'
Returns:
A list of CourseUserGroup objects. Empty if there are no cohorts.
"""
return list(CourseUserGroup.objects.filter(course_id=course_id,
group_type=CourseUserGroup.COHORT))
import django.test
from django.contrib.auth.models import User from django.contrib.auth.models import User
from nose.tools import assert_equals
from course_groups.models import CourseUserGroup, get_cohort from course_groups.models import CourseUserGroup, get_cohort, get_course_cohorts
def test_get_cohort(): class TestCohorts(django.test.TestCase):
def test_get_cohort(self):
course_id = "a/b/c" course_id = "a/b/c"
cohort = CourseUserGroup.objects.create(name="TestCohort", course_id=course_id, cohort = CourseUserGroup.objects.create(name="TestCohort", course_id=course_id,
group_type=CourseUserGroup.COHORT) group_type=CourseUserGroup.COHORT)
...@@ -14,8 +16,29 @@ def test_get_cohort(): ...@@ -14,8 +16,29 @@ def test_get_cohort():
cohort.users.add(user) cohort.users.add(user)
got = get_cohort(user, course_id) got = get_cohort(user, course_id)
assert_equals(got.id, cohort.id, "Should find the right cohort") self.assertEquals(got.id, cohort.id, "Should find the right cohort")
got = get_cohort(other_user, course_id) got = get_cohort(other_user, course_id)
assert_equals(got, None, "other_user shouldn't have a cohort") self.assertEquals(got, None, "other_user shouldn't have a cohort")
def test_get_course_cohorts(self):
course1_id = "a/b/c"
course2_id = "e/f/g"
# add some cohorts to course 1
cohort = CourseUserGroup.objects.create(name="TestCohort",
course_id=course1_id,
group_type=CourseUserGroup.COHORT)
cohort = CourseUserGroup.objects.create(name="TestCohort2",
course_id=course1_id,
group_type=CourseUserGroup.COHORT)
# second course should have no cohorts
self.assertEqual(get_course_cohorts(course2_id), [])
cohorts = sorted([c.name for c in get_course_cohorts(course1_id)])
self.assertEqual(cohorts, ['TestCohort', 'TestCohort2'])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment