Commit 05019e9f by Zia Fazal Committed by Xavier Antoviaque

added course_count

parent 944c8020
...@@ -12,6 +12,9 @@ from django.core.cache import cache ...@@ -12,6 +12,9 @@ from django.core.cache import cache
from django.test import TestCase, Client from django.test import TestCase, Client
from django.test.utils import override_settings from django.test.utils import override_settings
from student.tests.factories import CourseEnrollmentFactory
from xmodule.modulestore.tests.factories import CourseFactory
from courseware.tests.modulestore_config import TEST_DATA_MIXED_MODULESTORE
TEST_API_KEY = str(uuid.uuid4()) TEST_API_KEY = str(uuid.uuid4())
...@@ -25,6 +28,7 @@ class SecureClient(Client): ...@@ -25,6 +28,7 @@ class SecureClient(Client):
super(SecureClient, self).__init__(*args, **kwargs) super(SecureClient, self).__init__(*args, **kwargs)
@override_settings(MODULESTORE=TEST_DATA_MIXED_MODULESTORE)
@override_settings(EDX_API_KEY=TEST_API_KEY) @override_settings(EDX_API_KEY=TEST_API_KEY)
class OrganizationsApiTests(TestCase): class OrganizationsApiTests(TestCase):
...@@ -47,6 +51,10 @@ class OrganizationsApiTests(TestCase): ...@@ -47,6 +51,10 @@ class OrganizationsApiTests(TestCase):
email=self.test_user_email, email=self.test_user_email,
username=self.test_user_username username=self.test_user_username
) )
self.course = CourseFactory.create()
self.second_course = CourseFactory.create(
number="899"
)
self.client = SecureClient() self.client = SecureClient()
cache.clear() cache.clear()
...@@ -252,3 +260,26 @@ class OrganizationsApiTests(TestCase): ...@@ -252,3 +260,26 @@ class OrganizationsApiTests(TestCase):
self.assertEqual(response.data[0]['id'], self.test_user.id) self.assertEqual(response.data[0]['id'], self.test_user.id)
self.assertEqual(response.data[0]['username'], self.test_user.username) self.assertEqual(response.data[0]['username'], self.test_user.username)
self.assertEqual(response.data[0]['email'], self.test_user.email) self.assertEqual(response.data[0]['email'], self.test_user.email)
def test_organizations_users_get_with_course_count(self):
CourseEnrollmentFactory.create(user=self.test_user, course_id=self.course.id)
CourseEnrollmentFactory.create(user=self.test_user, course_id=self.second_course.id)
data = {
'name': self.test_organization_name,
'display_name': self.test_organization_display_name,
'contact_name': self.test_organization_contact_name,
'contact_email': self.test_organization_contact_email,
'contact_phone': self.test_organization_contact_phone
}
response = self.do_post(self.test_organizations_uri, data)
self.assertEqual(response.status_code, 201)
test_uri = '{}{}/'.format(self.test_organizations_uri, str(response.data['id']))
users_uri = '{}users/'.format(test_uri)
data = {"id": self.test_user.id}
response = self.do_post(users_uri, data)
self.assertEqual(response.status_code, 201)
response = self.do_get('{}{}'.format(users_uri, '?include_course_counts=True'))
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data[0]['id'], self.test_user.id)
self.assertEqual(response.data[0]['course_count'], 2)
...@@ -9,6 +9,8 @@ from rest_framework.decorators import action ...@@ -9,6 +9,8 @@ from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from api_manager.models import Organization from api_manager.models import Organization
from api_manager.utils import str2bool
from student.models import CourseEnrollment
from .serializers import OrganizationSerializer, UserSerializer from .serializers import OrganizationSerializer, UserSerializer
...@@ -26,12 +28,17 @@ class OrganizationsViewSet(viewsets.ModelViewSet): ...@@ -26,12 +28,17 @@ class OrganizationsViewSet(viewsets.ModelViewSet):
Add a User to an Organization Add a User to an Organization
""" """
if request.method == 'GET': if request.method == 'GET':
include_course_counts = request.QUERY_PARAMS.get('include_course_counts', None)
users = User.objects.filter(organizations=pk) users = User.objects.filter(organizations=pk)
response_data = [] response_data = []
if users: if users:
for user in users: for user in users:
serializer = UserSerializer(user) serializer = UserSerializer(user)
response_data.append(serializer.data) # pylint: disable=E1101 user_data = serializer.data
if str2bool(include_course_counts):
enrollments = CourseEnrollment.enrollments_for_user(user).count()
user_data['course_count'] = enrollments
response_data.append(user_data)
return Response(response_data, status=status.HTTP_200_OK) return Response(response_data, status=status.HTTP_200_OK)
else: else:
user_id = request.DATA.get('id') user_id = request.DATA.get('id')
......
...@@ -32,7 +32,10 @@ def str2bool(value): ...@@ -32,7 +32,10 @@ def str2bool(value):
""" """
convert string to bool convert string to bool
""" """
return value.lower() in ("true",) if value:
return value.lower() in ("true",)
else:
return False
def generate_base_uri(request, strip_qs=False): def generate_base_uri(request, strip_qs=False):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment