Commit 2e06e592 by Calen Pennington

Use DjangoXBlockUserStateClient to implement UserStateCache

parent 257660ed
...@@ -24,6 +24,7 @@ from xblock.exceptions import KeyValueMultiSaveError, InvalidScopeError ...@@ -24,6 +24,7 @@ from xblock.exceptions import KeyValueMultiSaveError, InvalidScopeError
from xblock.fields import Scope, UserScope from xblock.fields import Scope, UserScope
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xblock.core import XBlockAside from xblock.core import XBlockAside
from courseware.user_state_client import DjangoXBlockUserStateClient
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -204,18 +205,28 @@ class DjangoOrmFieldCache(object): ...@@ -204,18 +205,28 @@ class DjangoOrmFieldCache(object):
cache_key = self._cache_key_for_kvs_key(kvs_key) cache_key = self._cache_key_for_kvs_key(kvs_key)
field_object = self._cache.get(cache_key) field_object = self._cache.get(cache_key)
if field_object is None:
self._cache[cache_key] = field_object = self._create_object(kvs_key)
self._set_field_value(field_object, kvs_key, value)
try: try:
field_object.save() serialized_value = json.dumps(value)
saved_fields.append(kvs_key.field_name) # It is safe to force an insert or an update, because
# a) we should have retrieved the object as part of the
# prefetch step, so if it isn't in our cache, it doesn't exist yet.
# b) no other code should be modifying these models out of band of
# this cache.
if field_object is None:
field_object = self._create_object(kvs_key, serialized_value)
field_object.save(force_insert=True)
self._cache[cache_key] = field_object
else:
field_object.value = serialized_value
field_object.save(force_update=True)
except DatabaseError: except DatabaseError:
log.exception("Saving field %r failed", kvs_key.field_name) log.exception("Saving field %r failed", kvs_key.field_name)
raise KeyValueMultiSaveError(saved_fields) raise KeyValueMultiSaveError(saved_fields)
finally:
saved_fields.append(kvs_key.field_name)
@contract(kvs_key=DjangoKeyValueStore.Key) @contract(kvs_key=DjangoKeyValueStore.Key)
def delete(self, kvs_key): def delete(self, kvs_key):
""" """
...@@ -264,15 +275,11 @@ class DjangoOrmFieldCache(object): ...@@ -264,15 +275,11 @@ class DjangoOrmFieldCache(object):
else: else:
return field_object.modified return field_object.modified
@contract(kvs_key=DjangoKeyValueStore.Key)
def _set_field_value(self, field_object, kvs_key, value):
field_object.value = json.dumps(value)
def __len__(self): def __len__(self):
return len(self._cache) return len(self._cache)
@abstractmethod @abstractmethod
def _create_object(self, kvs_key): def _create_object(self, kvs_key, value):
""" """
Create a new object to add to the cache (which should record Create a new object to add to the cache (which should record
the specified field ``value`` for the field identified by the specified field ``value`` for the field identified by
...@@ -280,6 +287,7 @@ class DjangoOrmFieldCache(object): ...@@ -280,6 +287,7 @@ class DjangoOrmFieldCache(object):
Arguments: Arguments:
kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for
value: What value to record in the field
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -324,9 +332,10 @@ class UserStateCache(object): ...@@ -324,9 +332,10 @@ class UserStateCache(object):
Cache for Scope.user_state xblock field data. Cache for Scope.user_state xblock field data.
""" """
def __init__(self, user, course_id): def __init__(self, user, course_id):
self._cache = {} self._cache = defaultdict(dict)
self.course_id = course_id self.course_id = course_id
self.user = user self.user = user
self._client = DjangoXBlockUserStateClient(self.user)
def cache_fields(self, fields, xblocks, aside_types): # pylint: disable=unused-argument def cache_fields(self, fields, xblocks, aside_types): # pylint: disable=unused-argument
""" """
...@@ -338,15 +347,12 @@ class UserStateCache(object): ...@@ -338,15 +347,12 @@ class UserStateCache(object):
xblocks (list of :class:`XBlock`): XBlocks to cache fields for. xblocks (list of :class:`XBlock`): XBlocks to cache fields for.
aside_types (list of str): Aside types to cache fields for. aside_types (list of str): Aside types to cache fields for.
""" """
query = StudentModule.objects.chunked_filter( block_field_state = self._client.get_many(
'module_state_key__in', self.user.username,
_all_usage_keys(xblocks, aside_types), _all_usage_keys(xblocks, aside_types),
course_id=self.course_id,
student=self.user.pk,
) )
for field_object in query: for usage_key, field_state in block_field_state:
cache_key = field_object.module_state_key.map_into_course(self.course_id) self._cache[usage_key] = field_state
self._cache[cache_key] = field_object
@contract(kvs_key=DjangoKeyValueStore.Key) @contract(kvs_key=DjangoKeyValueStore.Key)
def set(self, kvs_key, value): def set(self, kvs_key, value):
...@@ -369,12 +375,11 @@ class UserStateCache(object): ...@@ -369,12 +375,11 @@ class UserStateCache(object):
Returns: datetime if there was a modified date, or None otherwise Returns: datetime if there was a modified date, or None otherwise
""" """
field_object = self._cache.get(self._cache_key_for_kvs_key(kvs_key)) return self._client.get_mod_date(
self.user.username,
if field_object is None: kvs_key.block_scope_id,
return None fields=[kvs_key.field_name],
else: ).get(kvs_key.field_name)
return field_object.modified
@contract(kv_dict="dict(DjangoKeyValueStore_Key: *)") @contract(kv_dict="dict(DjangoKeyValueStore_Key: *)")
def set_many(self, kv_dict): def set_many(self, kv_dict):
...@@ -385,25 +390,21 @@ class UserStateCache(object): ...@@ -385,25 +390,21 @@ class UserStateCache(object):
kv_dict (dict): A dictionary mapping :class:`~DjangoKeyValueStore.Key` kv_dict (dict): A dictionary mapping :class:`~DjangoKeyValueStore.Key`
objects to values to set. objects to values to set.
""" """
dirty_field_objects = defaultdict(set) pending_updates = defaultdict(dict)
for kvs_key, value in kv_dict.items(): for kvs_key, value in kv_dict.items():
cache_key = self._cache_key_for_kvs_key(kvs_key) cache_key = self._cache_key_for_kvs_key(kvs_key)
field_object = self._cache.get(cache_key)
if field_object is None: pending_updates[cache_key][kvs_key.field_name] = value
self._cache[cache_key] = field_object = self._create_object(kvs_key)
self._set_field_value(field_object, kvs_key, value) try:
dirty_field_objects[field_object].add(kvs_key.field_name) self._client.set_many(
self.user.username,
saved_fields = [] pending_updates
for field_object, fields in sorted(dirty_field_objects.iteritems()): )
try: except DatabaseError:
field_object.save() raise KeyValueMultiSaveError([])
saved_fields.extend(fields) finally:
except DatabaseError: self._cache.update(pending_updates)
log.exception("Saving fields %r failed", fields)
raise KeyValueMultiSaveError(saved_fields)
@contract(kvs_key=DjangoKeyValueStore.Key) @contract(kvs_key=DjangoKeyValueStore.Key)
def get(self, kvs_key): def get(self, kvs_key):
...@@ -420,9 +421,7 @@ class UserStateCache(object): ...@@ -420,9 +421,7 @@ class UserStateCache(object):
if cache_key not in self._cache: if cache_key not in self._cache:
raise KeyError(kvs_key.field_name) raise KeyError(kvs_key.field_name)
field_object = self._cache[cache_key] return self._cache[cache_key][kvs_key.field_name]
return json.loads(field_object.state)[kvs_key.field_name]
@contract(kvs_key=DjangoKeyValueStore.Key) @contract(kvs_key=DjangoKeyValueStore.Key)
def delete(self, kvs_key): def delete(self, kvs_key):
...@@ -434,15 +433,17 @@ class UserStateCache(object): ...@@ -434,15 +433,17 @@ class UserStateCache(object):
Raises: KeyError if key isn't found in the cache Raises: KeyError if key isn't found in the cache
""" """
cache_key = self._cache_key_for_kvs_key(kvs_key)
if cache_key not in self._cache:
raise KeyError(kvs_key.field_name)
field_object = self._cache.get(self._cache_key_for_kvs_key(kvs_key)) field_state = self._cache[cache_key]
if field_object is None:
if kvs_key.field_name not in field_state:
raise KeyError(kvs_key.field_name) raise KeyError(kvs_key.field_name)
state = json.loads(field_object.state) self._client.delete(self.user.username, cache_key, fields=[kvs_key.field_name])
del state[kvs_key.field_name] del field_state[kvs_key.field_name]
field_object.state = json.dumps(state)
field_object.save()
@contract(kvs_key=DjangoKeyValueStore.Key, returns=bool) @contract(kvs_key=DjangoKeyValueStore.Key, returns=bool)
def has(self, kvs_key): def has(self, kvs_key):
...@@ -454,11 +455,12 @@ class UserStateCache(object): ...@@ -454,11 +455,12 @@ class UserStateCache(object):
Returns: bool Returns: bool
""" """
field_object = self._cache.get(self._cache_key_for_kvs_key(kvs_key)) cache_key = self._cache_key_for_kvs_key(kvs_key)
if field_object is None:
return False
return kvs_key.field_name in json.loads(field_object.state) return (
cache_key in self._cache and
kvs_key.field_name in self._cache[cache_key]
)
@contract(user_id=int, usage_key=UsageKey, score="number|None", max_score="number|None") @contract(user_id=int, usage_key=UsageKey, score="number|None", max_score="number|None")
def set_score(self, user_id, usage_key, score, max_score): def set_score(self, user_id, usage_key, score, max_score):
...@@ -467,30 +469,23 @@ class UserStateCache(object): ...@@ -467,30 +469,23 @@ class UserStateCache(object):
Set the score and max_score for the specified user and xblock usage. Set the score and max_score for the specified user and xblock usage.
""" """
field_object = self._cache[usage_key] student_module, created = StudentModule.objects.get_or_create(
field_object.grade = score student_id=user_id,
field_object.max_grade = max_score module_state_key=usage_key,
field_object.save() course_id=usage_key.course_key,
defaults={
@contract(kvs_key=DjangoKeyValueStore.Key) 'grade': score,
def _set_field_value(self, field_object, kvs_key, value): 'max_grade': max_score,
field_object.value = json.dumps(value) }
)
if not created:
student_module.grade = score
student_module.max_grade = max_score
student_module.save()
def __len__(self): def __len__(self):
return len(self._cache) return len(self._cache)
def _create_object(self, kvs_key):
field_object, __ = StudentModule.objects.get_or_create(
course_id=self.course_id,
student_id=kvs_key.user_id,
module_state_key=kvs_key.block_scope_id,
defaults={
'state': json.dumps({}),
'module_type': kvs_key.block_scope_id.block_type,
},
)
return field_object
def _cache_key_for_kvs_key(self, key): def _cache_key_for_kvs_key(self, key):
""" """
Return the key used in this DjangoOrmFieldCache for the specified KeyValueStore key. Return the key used in this DjangoOrmFieldCache for the specified KeyValueStore key.
...@@ -500,12 +495,6 @@ class UserStateCache(object): ...@@ -500,12 +495,6 @@ class UserStateCache(object):
""" """
return key.block_scope_id return key.block_scope_id
@contract(kvs_key=DjangoKeyValueStore.Key)
def _set_field_value(self, field_object, kvs_key, value):
state = json.loads(field_object.state)
state[kvs_key.field_name] = value
field_object.state = json.dumps(state)
class UserStateSummaryCache(DjangoOrmFieldCache): class UserStateSummaryCache(DjangoOrmFieldCache):
""" """
...@@ -515,7 +504,7 @@ class UserStateSummaryCache(DjangoOrmFieldCache): ...@@ -515,7 +504,7 @@ class UserStateSummaryCache(DjangoOrmFieldCache):
super(UserStateSummaryCache, self).__init__() super(UserStateSummaryCache, self).__init__()
self.course_id = course_id self.course_id = course_id
def _create_object(self, kvs_key): def _create_object(self, kvs_key, value):
""" """
Create a new object to add to the cache (which should record Create a new object to add to the cache (which should record
the specified field ``value`` for the field identified by the specified field ``value`` for the field identified by
...@@ -523,12 +512,13 @@ class UserStateSummaryCache(DjangoOrmFieldCache): ...@@ -523,12 +512,13 @@ class UserStateSummaryCache(DjangoOrmFieldCache):
Arguments: Arguments:
kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for
value: The value to assign to the new field object
""" """
field_object, __ = XModuleUserStateSummaryField.objects.get_or_create( return XModuleUserStateSummaryField(
field_name=kvs_key.field_name, field_name=kvs_key.field_name,
usage_id=kvs_key.block_scope_id usage_id=kvs_key.block_scope_id,
value=value,
) )
return field_object
def _read_objects(self, fields, xblocks, aside_types): def _read_objects(self, fields, xblocks, aside_types):
""" """
...@@ -575,7 +565,7 @@ class PreferencesCache(DjangoOrmFieldCache): ...@@ -575,7 +565,7 @@ class PreferencesCache(DjangoOrmFieldCache):
super(PreferencesCache, self).__init__() super(PreferencesCache, self).__init__()
self.user = user self.user = user
def _create_object(self, kvs_key): def _create_object(self, kvs_key, value):
""" """
Create a new object to add to the cache (which should record Create a new object to add to the cache (which should record
the specified field ``value`` for the field identified by the specified field ``value`` for the field identified by
...@@ -583,13 +573,14 @@ class PreferencesCache(DjangoOrmFieldCache): ...@@ -583,13 +573,14 @@ class PreferencesCache(DjangoOrmFieldCache):
Arguments: Arguments:
kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for
value: The value to assign to the new field object
""" """
field_object, __ = XModuleStudentPrefsField.objects.get_or_create( return XModuleStudentPrefsField(
field_name=kvs_key.field_name, field_name=kvs_key.field_name,
module_type=BlockTypeKeyV1(kvs_key.block_family, kvs_key.block_scope_id), module_type=BlockTypeKeyV1(kvs_key.block_family, kvs_key.block_scope_id),
student_id=kvs_key.user_id, student_id=kvs_key.user_id,
value=value,
) )
return field_object
def _read_objects(self, fields, xblocks, aside_types): def _read_objects(self, fields, xblocks, aside_types):
""" """
...@@ -637,7 +628,7 @@ class UserInfoCache(DjangoOrmFieldCache): ...@@ -637,7 +628,7 @@ class UserInfoCache(DjangoOrmFieldCache):
super(UserInfoCache, self).__init__() super(UserInfoCache, self).__init__()
self.user = user self.user = user
def _create_object(self, kvs_key): def _create_object(self, kvs_key, value):
""" """
Create a new object to add to the cache (which should record Create a new object to add to the cache (which should record
the specified field ``value`` for the field identified by the specified field ``value`` for the field identified by
...@@ -645,12 +636,13 @@ class UserInfoCache(DjangoOrmFieldCache): ...@@ -645,12 +636,13 @@ class UserInfoCache(DjangoOrmFieldCache):
Arguments: Arguments:
kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for kvs_key (:class:`DjangoKeyValueStore.Key`): Which field to create an entry for
value: The value to assign to the new field object
""" """
field_object, __ = XModuleStudentInfoField.objects.get_or_create( return XModuleStudentInfoField(
field_name=kvs_key.field_name, field_name=kvs_key.field_name,
student_id=kvs_key.user_id, student_id=kvs_key.user_id,
value=value,
) )
return field_object
def _read_objects(self, fields, xblocks, aside_types): def _read_objects(self, fields, xblocks, aside_types):
""" """
......
...@@ -134,7 +134,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase): ...@@ -134,7 +134,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase):
def test_set_existing_field(self): def test_set_existing_field(self):
"Test that setting an existing user_state field changes the value" "Test that setting an existing user_state field changes the value"
# We are updating a problem, so we write to courseware_studentmodulehistory # We are updating a problem, so we write to courseware_studentmodulehistory
# as well as courseware_studentmodule # as well as courseware_studentmodule. We also need to read the database
# to discover if something other than the DjangoXBlockUserStateClient
# has written to the StudentModule (such as UserStateCache setting the score
# on the StudentModule).
with self.assertNumQueries(3): with self.assertNumQueries(3):
self.kvs.set(user_state_key('a_field'), 'new_value') self.kvs.set(user_state_key('a_field'), 'new_value')
self.assertEquals(1, StudentModule.objects.all().count()) self.assertEquals(1, StudentModule.objects.all().count())
...@@ -143,7 +146,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase): ...@@ -143,7 +146,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase):
def test_set_missing_field(self): def test_set_missing_field(self):
"Test that setting a new user_state field changes the value" "Test that setting a new user_state field changes the value"
# We are updating a problem, so we write to courseware_studentmodulehistory # We are updating a problem, so we write to courseware_studentmodulehistory
# as well as courseware_studentmodule # as well as courseware_studentmodule. We also need to read the database
# to discover if something other than the DjangoXBlockUserStateClient
# has written to the StudentModule (such as UserStateCache setting the score
# on the StudentModule).
with self.assertNumQueries(3): with self.assertNumQueries(3):
self.kvs.set(user_state_key('not_a_field'), 'new_value') self.kvs.set(user_state_key('not_a_field'), 'new_value')
self.assertEquals(1, StudentModule.objects.all().count()) self.assertEquals(1, StudentModule.objects.all().count())
...@@ -152,7 +158,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase): ...@@ -152,7 +158,10 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase):
def test_delete_existing_field(self): def test_delete_existing_field(self):
"Test that deleting an existing field removes it from the StudentModule" "Test that deleting an existing field removes it from the StudentModule"
# We are updating a problem, so we write to courseware_studentmodulehistory # We are updating a problem, so we write to courseware_studentmodulehistory
# as well as courseware_studentmodule # as well as courseware_studentmodule. We also need to read the database
# to discover if something other than the DjangoXBlockUserStateClient
# has written to the StudentModule (such as UserStateCache setting the score
# on the StudentModule).
with self.assertNumQueries(3): with self.assertNumQueries(3):
self.kvs.delete(user_state_key('a_field')) self.kvs.delete(user_state_key('a_field'))
self.assertEquals(1, StudentModule.objects.all().count()) self.assertEquals(1, StudentModule.objects.all().count())
...@@ -190,6 +199,9 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase): ...@@ -190,6 +199,9 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase):
# Scope.user_state is stored in a single row in the database, so we only # Scope.user_state is stored in a single row in the database, so we only
# need to send a single update to that table. # need to send a single update to that table.
# We also are updating a problem, so we write to courseware student module history # We also are updating a problem, so we write to courseware student module history
# We also need to read the database to discover if something other than the
# DjangoXBlockUserStateClient has written to the StudentModule (such as
# UserStateCache setting the score on the StudentModule).
with self.assertNumQueries(3): with self.assertNumQueries(3):
self.kvs.set_many(kv_dict) self.kvs.set_many(kv_dict)
...@@ -207,7 +219,7 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase): ...@@ -207,7 +219,7 @@ class TestStudentModuleStorage(OtherUserFailureTestMixin, TestCase):
with patch('django.db.models.Model.save', side_effect=DatabaseError): with patch('django.db.models.Model.save', side_effect=DatabaseError):
with self.assertRaises(KeyValueMultiSaveError) as exception_context: with self.assertRaises(KeyValueMultiSaveError) as exception_context:
self.kvs.set_many(kv_dict) self.kvs.set_many(kv_dict)
self.assertEquals(len(exception_context.exception.saved_field_names), 0) self.assertEquals(exception_context.exception.saved_field_names, [])
@attr('shard_1') @attr('shard_1')
...@@ -234,8 +246,11 @@ class TestMissingStudentModule(TestCase): ...@@ -234,8 +246,11 @@ class TestMissingStudentModule(TestCase):
self.assertEquals(0, StudentModule.objects.all().count()) self.assertEquals(0, StudentModule.objects.all().count())
# We are updating a problem, so we write to courseware_studentmodulehistory # We are updating a problem, so we write to courseware_studentmodulehistory
# as well as courseware_studentmodule # as well as courseware_studentmodule. We also need to read the database
with self.assertNumQueries(6): # to discover if something other than the DjangoXBlockUserStateClient
# has written to the StudentModule (such as UserStateCache setting the score
# on the StudentModule).
with self.assertNumQueries(3):
self.kvs.set(user_state_key('a_field'), 'a_value') self.kvs.set(user_state_key('a_field'), 'a_value')
self.assertEquals(1, sum(len(cache) for cache in self.field_data_cache.cache.values())) self.assertEquals(1, sum(len(cache) for cache in self.field_data_cache.cache.values()))
...@@ -289,7 +304,7 @@ class StorageTestBase(object): ...@@ -289,7 +304,7 @@ class StorageTestBase(object):
self.kvs = DjangoKeyValueStore(self.field_data_cache) self.kvs = DjangoKeyValueStore(self.field_data_cache)
def test_set_and_get_existing_field(self): def test_set_and_get_existing_field(self):
with self.assertNumQueries(2): with self.assertNumQueries(1):
self.kvs.set(self.key_factory('existing_field'), 'test_value') self.kvs.set(self.key_factory('existing_field'), 'test_value')
with self.assertNumQueries(0): with self.assertNumQueries(0):
self.assertEquals('test_value', self.kvs.get(self.key_factory('existing_field'))) self.assertEquals('test_value', self.kvs.get(self.key_factory('existing_field')))
...@@ -306,14 +321,14 @@ class StorageTestBase(object): ...@@ -306,14 +321,14 @@ class StorageTestBase(object):
def test_set_existing_field(self): def test_set_existing_field(self):
"Test that setting an existing field changes the value" "Test that setting an existing field changes the value"
with self.assertNumQueries(2): with self.assertNumQueries(1):
self.kvs.set(self.key_factory('existing_field'), 'new_value') self.kvs.set(self.key_factory('existing_field'), 'new_value')
self.assertEquals(1, self.storage_class.objects.all().count()) self.assertEquals(1, self.storage_class.objects.all().count())
self.assertEquals('new_value', json.loads(self.storage_class.objects.all()[0].value)) self.assertEquals('new_value', json.loads(self.storage_class.objects.all()[0].value))
def test_set_missing_field(self): def test_set_missing_field(self):
"Test that setting a new field changes the value" "Test that setting a new field changes the value"
with self.assertNumQueries(4): with self.assertNumQueries(1):
self.kvs.set(self.key_factory('missing_field'), 'new_value') self.kvs.set(self.key_factory('missing_field'), 'new_value')
self.assertEquals(2, self.storage_class.objects.all().count()) self.assertEquals(2, self.storage_class.objects.all().count())
self.assertEquals('old_value', json.loads(self.storage_class.objects.get(field_name='existing_field').value)) self.assertEquals('old_value', json.loads(self.storage_class.objects.get(field_name='existing_field').value))
...@@ -355,7 +370,7 @@ class StorageTestBase(object): ...@@ -355,7 +370,7 @@ class StorageTestBase(object):
# Each field is a separate row in the database, hence # Each field is a separate row in the database, hence
# a separate query # a separate query
with self.assertNumQueries(len(kv_dict) * 3): with self.assertNumQueries(len(kv_dict)):
self.kvs.set_many(kv_dict) self.kvs.set_many(kv_dict)
for key in kv_dict: for key in kv_dict:
self.assertEquals(self.kvs.get(key), kv_dict[key]) self.assertEquals(self.kvs.get(key), kv_dict[key])
...@@ -363,8 +378,8 @@ class StorageTestBase(object): ...@@ -363,8 +378,8 @@ class StorageTestBase(object):
def test_set_many_failure(self): def test_set_many_failure(self):
"""Test that setting many regular fields with a DB error """ """Test that setting many regular fields with a DB error """
kv_dict = self.construct_kv_dict() kv_dict = self.construct_kv_dict()
with self.assertNumQueries(6): for key in kv_dict:
for key in kv_dict: with self.assertNumQueries(1):
self.kvs.set(key, 'test value') self.kvs.set(key, 'test value')
with patch('django.db.models.Model.save', side_effect=[None, DatabaseError]): with patch('django.db.models.Model.save', side_effect=[None, DatabaseError]):
...@@ -372,8 +387,7 @@ class StorageTestBase(object): ...@@ -372,8 +387,7 @@ class StorageTestBase(object):
self.kvs.set_many(kv_dict) self.kvs.set_many(kv_dict)
exception = exception_context.exception exception = exception_context.exception
self.assertEquals(len(exception.saved_field_names), 1) self.assertEquals(exception.saved_field_names, ['existing_field', 'other_existing_field'])
self.assertIn(exception.saved_field_names[0], ('existing_field', 'other_existing_field'))
class TestUserStateSummaryStorage(StorageTestBase, TestCase): class TestUserStateSummaryStorage(StorageTestBase, TestCase):
......
...@@ -40,7 +40,6 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient): ...@@ -40,7 +40,6 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient):
pass pass
def __init__(self, user): def __init__(self, user):
self._student_module_cache = {}
self.user = user self.user = user
def get(self, username, block_key, scope=Scope.user_state, fields=None): def get(self, username, block_key, scope=Scope.user_state, fields=None):
...@@ -93,6 +92,28 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient): ...@@ -93,6 +92,28 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient):
assert self.user.username == username assert self.user.username == username
return self.delete_many(username, [block_key], scope, fields=fields) return self.delete_many(username, [block_key], scope, fields=fields)
@contract(username="basestring", block_key=UsageKey, scope=ScopeBase, fields="seq(basestring)|set(basestring)|None")
def get_mod_date(self, username, block_key, scope=Scope.user_state, fields=None):
"""
Get the last modification date for fields from the specified blocks.
Arguments:
username: The name of the user whose state should be deleted
block_key (UsageKey): The UsageKey identifying which xblock modification dates to retrieve.
scope (Scope): The scope to retrieve from.
fields: A list of fields to query. If None, delete all stored fields.
Specific implementations are free to return the same modification date
for all fields, if they don't store changes individually per field.
Implementations may omit fields for which data has not been stored.
Returns: list a dict of {field_name: modified_date} for each selected field.
"""
results = self.get_mod_date_many(username, [block_key], scope, fields=fields)
return {
field: date for (_, field, date) in results
}
@contract(username="basestring", block_keys="seq(UsageKey)|set(UsageKey)")
def _get_field_objects(self, username, block_keys): def _get_field_objects(self, username, block_keys):
""" """
Retrieve the :class:`~StudentModule`s for the supplied ``username`` and ``block_keys``. Retrieve the :class:`~StudentModule`s for the supplied ``username`` and ``block_keys``.
...@@ -108,23 +129,15 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient): ...@@ -108,23 +129,15 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient):
) )
for course_key, usage_keys in by_course: for course_key, usage_keys in by_course:
not_cached = []
for usage_key in usage_keys:
if usage_key in self._student_module_cache:
yield self._student_module_cache[usage_key]
else:
not_cached.append(usage_key)
query = StudentModule.objects.chunked_filter( query = StudentModule.objects.chunked_filter(
'module_state_key__in', 'module_state_key__in',
not_cached, usage_keys,
student__username=username, student__username=username,
course_id=course_key, course_id=course_key,
) )
for student_module in query: for student_module in query:
usage_key = student_module.module_state_key.map_into_course(student_module.course_id) usage_key = student_module.module_state_key.map_into_course(student_module.course_id)
self._student_module_cache[usage_key] = student_module
yield (student_module, usage_key) yield (student_module, usage_key)
...@@ -170,14 +183,10 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient): ...@@ -170,14 +183,10 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient):
if scope != Scope.user_state: if scope != Scope.user_state:
raise ValueError("Only Scope.user_state is supported") raise ValueError("Only Scope.user_state is supported")
field_objects = self._get_field_objects(username, block_keys_to_state.keys()) # We do a find_or_create for every block (rather than re-using field objects
for field_object in field_objects: # that were queried in get_many) so that if the score has
usage_key = field_object.module_state_key.map_into_course(field_object.course_id) # been changed by some other piece of the code, we don't overwrite
current_state = json.loads(field_object.state) # that score.
current_state.update(block_keys_to_state.pop(usage_key))
field_object.state = json.dumps(current_state)
field_object.save()
for usage_key, state in block_keys_to_state.items(): for usage_key, state in block_keys_to_state.items():
student_module, created = StudentModule.objects.get_or_create( student_module, created = StudentModule.objects.get_or_create(
student=self.user, student=self.user,
...@@ -227,6 +236,39 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient): ...@@ -227,6 +236,39 @@ class DjangoXBlockUserStateClient(XBlockUserStateClient):
# We just read this object, so we know that we can do an update # We just read this object, so we know that we can do an update
student_module.save(force_update=True) student_module.save(force_update=True)
@contract(
username="basestring",
block_keys="seq(UsageKey)|set(UsageKey)",
scope=ScopeBase,
fields="seq(basestring)|set(basestring)|None"
)
def get_mod_date_many(self, username, block_keys, scope=Scope.user_state, fields=None):
"""
Get the last modification date for fields from the specified blocks.
Arguments:
username: The name of the user whose state should be deleted
block_key (UsageKey): The UsageKey identifying which xblock modification dates to retrieve.
scope (Scope): The scope to retrieve from.
fields: A list of fields to query. If None, delete all stored fields.
Specific implementations are free to return the same modification date
for all fields, if they don't store changes individually per field.
Implementations may omit fields for which data has not been stored.
Yields: tuples of (block, field_name, modified_date) for each selected field.
"""
assert self.user.username == username
if scope != Scope.user_state:
raise ValueError("Only Scope.user_state is supported")
field_objects = self._get_field_objects(username, block_keys)
for field_object, usage_key in field_objects:
if field_object.state is None:
continue
for field in json.loads(field_object.state):
yield (usage_key, field, field_object.modified)
def get_history(self, username, block_key, scope=Scope.user_state): def get_history(self, username, block_key, scope=Scope.user_state):
"""We don't guarantee that history for many blocks will be fast.""" """We don't guarantee that history for many blocks will be fast."""
assert self.user.username == username assert self.user.username == username
......
...@@ -147,7 +147,7 @@ class UserCourseStatus(views.APIView): ...@@ -147,7 +147,7 @@ class UserCourseStatus(views.APIView):
scope=Scope.user_state, scope=Scope.user_state,
user_id=request.user.id, user_id=request.user.id,
block_scope_id=course.location, block_scope_id=course.location,
field_name=None field_name='position'
) )
original_store_date = field_data_cache.last_modified(key) original_store_date = field_data_cache.last_modified(key)
if original_store_date is not None and modification_date < original_store_date: if original_store_date is not None and modification_date < original_store_date:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment