Commit 055d378a by Chris Dodge Committed by Jonathan Piacenti

Use a simplier serializer and do the course counts as a single aggregate query

parent 3125c9c7
...@@ -42,6 +42,15 @@ class UserSerializer(DynamicFieldsModelSerializer): ...@@ -42,6 +42,15 @@ class UserSerializer(DynamicFieldsModelSerializer):
fields = ("id", "email", "username", "first_name", "last_name", "created", "is_active", "organizations", "avatar_url", "city", "title", "country", "full_name") fields = ("id", "email", "username", "first_name", "last_name", "created", "is_active", "organizations", "avatar_url", "city", "title", "country", "full_name")
read_only_fields = ("id", "email", "username") read_only_fields = ("id", "email", "username")
class SimpleUserSerializer(DynamicFieldsModelSerializer):
created = serializers.DateTimeField(source='date_joined', required=False)
class Meta:
""" Serializer/field specification """
model = APIUser
fields = ("id", "email", "username", "first_name", "last_name", "created", "is_active")
read_only_fields = ("id", "email", "username")
class UserCountByCitySerializer(serializers.Serializer): class UserCountByCitySerializer(serializers.Serializer):
""" Serializer for user count by city """ """ Serializer for user count by city """
city = serializers.CharField(source='profile__city') city = serializers.CharField(source='profile__city')
......
...@@ -62,6 +62,14 @@ class OrganizationsApiTests(ModuleStoreTestCase): ...@@ -62,6 +62,14 @@ class OrganizationsApiTests(ModuleStoreTestCase):
profile.city = 'Boston' profile.city = 'Boston'
profile.save() profile.save()
self.test_user2 = User.objects.create(
email=str(uuid.uuid4()),
username=str(uuid.uuid4())
)
profile2 = UserProfile(user=self.test_user2)
profile2.city = 'NYC'
profile2.save()
self.course = CourseFactory.create() self.course = CourseFactory.create()
self.second_course = CourseFactory.create( self.second_course = CourseFactory.create(
number="899" number="899"
...@@ -322,6 +330,7 @@ class OrganizationsApiTests(ModuleStoreTestCase): ...@@ -322,6 +330,7 @@ class OrganizationsApiTests(ModuleStoreTestCase):
def test_organizations_users_get_with_course_count(self): def test_organizations_users_get_with_course_count(self):
CourseEnrollmentFactory.create(user=self.test_user, course_id=self.course.id) CourseEnrollmentFactory.create(user=self.test_user, course_id=self.course.id)
CourseEnrollmentFactory.create(user=self.test_user2, course_id=self.course.id)
CourseEnrollmentFactory.create(user=self.test_user, course_id=self.second_course.id) CourseEnrollmentFactory.create(user=self.test_user, course_id=self.second_course.id)
data = { data = {
...@@ -338,10 +347,16 @@ class OrganizationsApiTests(ModuleStoreTestCase): ...@@ -338,10 +347,16 @@ class OrganizationsApiTests(ModuleStoreTestCase):
data = {"id": self.test_user.id} data = {"id": self.test_user.id}
response = self.do_post(users_uri, data) response = self.do_post(users_uri, data)
self.assertEqual(response.status_code, 201) self.assertEqual(response.status_code, 201)
data = {"id": self.test_user2.id}
response = self.do_post(users_uri, data)
self.assertEqual(response.status_code, 201)
response = self.do_get('{}{}'.format(users_uri, '?include_course_counts=True')) response = self.do_get('{}{}'.format(users_uri, '?include_course_counts=True'))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data[0]['id'], self.test_user.id) self.assertEqual(response.data[0]['id'], self.test_user.id)
self.assertEqual(response.data[0]['course_count'], 2) self.assertEqual(response.data[0]['course_count'], 2)
self.assertEqual(response.data[1]['id'], self.test_user2.id)
self.assertEqual(response.data[1]['course_count'], 1)
def test_organizations_users_get_with_grades(self): def test_organizations_users_get_with_grades(self):
# Create 4 users # Create 4 users
......
...@@ -12,7 +12,7 @@ from rest_framework.response import Response ...@@ -12,7 +12,7 @@ from rest_framework.response import Response
from api_manager.courseware_access import get_course_key, get_aggregate_exclusion_user_ids from api_manager.courseware_access import get_course_key, get_aggregate_exclusion_user_ids
from organizations.models import Organization from organizations.models import Organization
from api_manager.users.serializers import UserSerializer from api_manager.users.serializers import UserSerializer, SimpleUserSerializer
from api_manager.groups.serializers import GroupSerializer from api_manager.groups.serializers import GroupSerializer
from api_manager.utils import str2bool from api_manager.utils import str2bool
from gradebook.models import StudentGradebook from gradebook.models import StudentGradebook
...@@ -92,19 +92,28 @@ class OrganizationsViewSet(viewsets.ModelViewSet): ...@@ -92,19 +92,28 @@ class OrganizationsViewSet(viewsets.ModelViewSet):
course_key = get_course_key(course_id) course_key = get_course_key(course_id)
users = User.objects.filter(organizations=pk) users = User.objects.filter(organizations=pk)
if course_key: if course_key:
users = users.filter(courseenrollment__course_id__exact=course_key, users = users.filter(courseenrollment__course_id__exact=course_key,
courseenrollment__is_active=True) courseenrollment__is_active=True)
if str2bool(include_grades): if str2bool(include_grades):
users = users.select_related('studentgradebook') users = users.select_related('studentgradebook')
if str2bool(include_course_counts):
enrollments = CourseEnrollment.objects.filter(user__in=users).values('user').order_by().annotate(total=Count('user'))
enrollments_by_user = {}
for enrollment in enrollments:
enrollments_by_user[enrollment['user']] = enrollment['total']
response_data = [] response_data = []
if users: if users:
for user in users: for user in users:
serializer = UserSerializer(user) serializer = SimpleUserSerializer(user)
user_data = serializer.data user_data = serializer.data
if str2bool(include_course_counts): if str2bool(include_course_counts):
enrollments = CourseEnrollment.enrollments_for_user(user).count() user_data['course_count'] = enrollments_by_user.get(user.id, 0)
user_data['course_count'] = enrollments
if str2bool(include_grades) and course_key: if str2bool(include_grades) and course_key:
user_grades = {'grade': 0, 'proforma_grade': 0} user_grades = {'grade': 0, 'proforma_grade': 0}
gradebook = user.studentgradebook_set.filter(course_id=course_key) gradebook = user.studentgradebook_set.filter(course_id=course_key)
...@@ -114,6 +123,7 @@ class OrganizationsViewSet(viewsets.ModelViewSet): ...@@ -114,6 +123,7 @@ class OrganizationsViewSet(viewsets.ModelViewSet):
user_grades['complete_status'] = True if 0 < gradebook[0].proforma_grade <= \ user_grades['complete_status'] = True if 0 < gradebook[0].proforma_grade <= \
gradebook[0].grade + grade_complete_match_range else False gradebook[0].grade + grade_complete_match_range else False
user_data.update(user_grades) user_data.update(user_grades)
response_data.append(user_data) response_data.append(user_data)
return Response(response_data, status=status.HTTP_200_OK) return Response(response_data, status=status.HTTP_200_OK)
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment