Commit 4b9f017d by Clinton Blackburn

Fixed bug for partial updates for Studio API course runs endpoint

Existing data remains intact when performing a partial update.

LEARNER-2468
parent d6604ebb
...@@ -112,6 +112,8 @@ class CourseRunSerializer(CourseRunTeamSerializerMixin, serializers.Serializer): ...@@ -112,6 +112,8 @@ class CourseRunSerializer(CourseRunTeamSerializerMixin, serializers.Serializer):
return instance return instance
def update_team(self, instance, team): def update_team(self, instance, team):
# Existing data should remain intact when performing a partial update.
if not self.partial:
CourseAccessRole.objects.filter(course_id=instance.id).delete() CourseAccessRole.objects.filter(course_id=instance.id).delete()
# TODO In the future we can optimize by getting users in a single query. # TODO In the future we can optimize by getting users in a single query.
......
...@@ -36,9 +36,11 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -36,9 +36,11 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
assert course_run.enrollment_end == enrollment_end assert course_run.enrollment_end == enrollment_end
def assert_access_role(self, course_run, user, role): def assert_access_role(self, course_run, user, role):
# An error will be raised if the endpoint doesn't create the role # An error will be raised if the endpoint did not create the role
CourseAccessRole.objects.get(course_id=course_run.id, user=user, role=role) CourseAccessRole.objects.get(course_id=course_run.id, user=user, role=role)
assert CourseAccessRole.objects.filter(course_id=course_run.id).count() == 1
def assert_course_access_role_count(self, course_run, expected):
assert CourseAccessRole.objects.filter(course_id=course_run.id).count() == expected
def get_serializer_context(self): def get_serializer_context(self):
return {'request': RequestFactory().get('')} return {'request': RequestFactory().get('')}
...@@ -111,6 +113,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -111,6 +113,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
response = self.client.put(url, data, format='json') response = self.client.put(url, data, format='json')
assert response.status_code == 200 assert response.status_code == 200
self.assert_access_role(course_run, user, role) self.assert_access_role(course_run, user, role)
self.assert_course_access_role_count(course_run, 1)
course_run = modulestore().get_course(course_run.id) course_run = modulestore().get_course(course_run.id)
assert response.data == CourseRunSerializer(course_run, context=self.get_serializer_context()).data assert response.data == CourseRunSerializer(course_run, context=self.get_serializer_context()).data
...@@ -134,12 +137,13 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -134,12 +137,13 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
assert response.data == {'team': [{'user': ['Object with username=test-user does not exist.']}]} assert response.data == {'team': [{'user': ['Object with username=test-user does not exist.']}]}
def test_partial_update(self): def test_partial_update(self):
role = 'staff'
start = datetime.datetime.now(pytz.UTC).replace(microsecond=0) start = datetime.datetime.now(pytz.UTC).replace(microsecond=0)
course_run = CourseFactory(start=start, end=None, enrollment_start=None, enrollment_end=None) course_run = CourseFactory(start=start, end=None, enrollment_start=None, enrollment_end=None)
assert CourseAccessRole.objects.filter(course_id=course_run.id).count() == 0 CourseAccessRole.objects.create(course_id=course_run.id, role=role, user=UserFactory())
assert CourseAccessRole.objects.filter(course_id=course_run.id).count() == 1
user = UserFactory() user = UserFactory()
role = 'staff'
data = { data = {
'team': [ 'team': [
{ {
...@@ -153,6 +157,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -153,6 +157,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
response = self.client.patch(url, data, format='json') response = self.client.patch(url, data, format='json')
assert response.status_code == 200 assert response.status_code == 200
self.assert_access_role(course_run, user, role) self.assert_access_role(course_run, user, role)
self.assert_course_access_role_count(course_run, 2)
course_run = modulestore().get_course(course_run.id) course_run = modulestore().get_course(course_run.id)
self.assert_course_run_schedule(course_run, start, None, None, None) self.assert_course_run_schedule(course_run, start, None, None, None)
...@@ -193,6 +198,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -193,6 +198,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
assert course_run.id.run == data['run'] assert course_run.id.run == data['run']
self.assert_course_run_schedule(course_run, start, end, enrollment_start, enrollment_end) self.assert_course_run_schedule(course_run, start, end, enrollment_start, enrollment_end)
self.assert_access_role(course_run, user, role) self.assert_access_role(course_run, user, role)
self.assert_course_access_role_count(course_run, 1)
def test_images_upload(self): def test_images_upload(self):
# http://www.django-rest-framework.org/api-guide/parsers/#fileuploadparser # http://www.django-rest-framework.org/api-guide/parsers/#fileuploadparser
...@@ -260,6 +266,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase): ...@@ -260,6 +266,7 @@ class CourseRunViewSetTests(ModuleStoreTestCase):
assert course_run.id.run == run assert course_run.id.run == run
self.assert_course_run_schedule(course_run, start, end, enrollment_start, enrollment_end) self.assert_course_run_schedule(course_run, start, end, enrollment_start, enrollment_end)
self.assert_access_role(course_run, user, role) self.assert_access_role(course_run, user, role)
self.assert_course_access_role_count(course_run, 1)
def test_rerun_duplicate_run(self): def test_rerun_duplicate_run(self):
course_run = ToyCourseFactory() course_run = ToyCourseFactory()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment