Commit f10bcc1e by Nimisha Asthagiri

Fix get_courses.

parent abbfa95e
...@@ -8,11 +8,11 @@ In this way, courses can be served up both - say - XMLModuleStore or MongoModule ...@@ -8,11 +8,11 @@ In this way, courses can be served up both - say - XMLModuleStore or MongoModule
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
import itertools import itertools
import functools
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
from opaque_keys.edx.locator import Locator
from . import ModuleStoreWriteBase from . import ModuleStoreWriteBase
from . import ModuleStoreEnum from . import ModuleStoreEnum
...@@ -25,40 +25,63 @@ log = logging.getLogger(__name__) ...@@ -25,40 +25,63 @@ log = logging.getLogger(__name__)
def strip_key(func): def strip_key(func):
"""
A decorator for stripping version and branch information from return values that are, or contain, locations.
Additionally, the decorated function is called with an optional 'field_decorator' parameter that can be used
to strip any location(-containing) fields, which are not directly returned by the function.
The behavior can be controlled by passing 'remove_version' and 'remove_branch' booleans to the decorated
function's kwargs.
"""
@functools.wraps(func)
def inner(*args, **kwargs): def inner(*args, **kwargs):
"""
Supported kwargs:
remove_version - If True, calls 'version_agnostic' on all return values, including those in lists and dicts.
remove_branch - If True, calls 'for_branch(None)' on all return values, including those in lists and dicts.
Note: The 'field_decorator' parameter passed to the decorated function is a function that honors the
values of these kwargs.
"""
# remove version and branch, by default # remove version and branch, by default
rem_vers = kwargs.pop('remove_version', True) rem_vers = kwargs.pop('remove_version', True)
rem_branch = kwargs.pop('remove_branch', False) rem_branch = kwargs.pop('remove_branch', False)
# helper function for stripping individual values
def strip_key_func(val): def strip_key_func(val):
"""
Strips the version and branch information according to the settings of rem_vers and rem_branch.
Recursively calls this function if the given value has a 'location' attribute.
"""
retval = val retval = val
if isinstance(retval, Locator): if rem_vers and hasattr(retval, 'version_agnostic'):
if rem_vers:
retval = retval.version_agnostic() retval = retval.version_agnostic()
if rem_branch: if rem_branch and hasattr(retval, 'for_branch'):
retval = retval.for_branch(None) retval = retval.for_branch(None)
if hasattr(retval, 'location'):
retval.location = strip_key_func(retval.location)
return retval return retval
# decorator for field values # function for stripping both, collection of, and individual, values
def strip_key_field_decorator(field_value): def strip_key_collection(field_value):
"""
Calls strip_key_func for each element in the given value.
"""
if rem_vers or rem_branch: if rem_vers or rem_branch:
if isinstance(field_value, list): if isinstance(field_value, list):
field_value = [strip_key_func(fv) for fv in field_value] field_value = [strip_key_func(fv) for fv in field_value]
elif isinstance(field_value, dict): elif isinstance(field_value, dict):
for key, val in field_value.iteritems(): for key, val in field_value.iteritems():
field_value[key] = strip_key_func(val) field_value[key] = strip_key_func(val)
elif hasattr(field_value, 'location'):
field_value.location = strip_key_func(field_value.location)
else: else:
field_value = strip_key_func(field_value) field_value = strip_key_func(field_value)
return field_value return field_value
# call the function # call the decorated function
retval = func(field_decorator=strip_key_field_decorator, *args, **kwargs) retval = func(field_decorator=strip_key_collection, *args, **kwargs)
# return the "decorated" value # strip the return value
return strip_key_field_decorator(retval) return strip_key_collection(retval)
return inner return inner
...@@ -219,10 +242,15 @@ class MixedModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase): ...@@ -219,10 +242,15 @@ class MixedModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase):
''' '''
Returns a list containing the top level XModuleDescriptors of the courses in this modulestore. Returns a list containing the top level XModuleDescriptors of the courses in this modulestore.
''' '''
courses = [] courses = {}
for store in self.modulestores: for store in self.modulestores:
courses.extend(store.get_courses(**kwargs)) # filter out ones which were fetched from earlier stores but locations may not be ==
return courses for course in store.get_courses(**kwargs):
course_id = self._clean_course_id_for_mapping(course.id)
if course_id not in courses:
# course is indeed unique. save it in result
courses[course_id] = course
return courses.values()
def make_course_key(self, org, course, run): def make_course_key(self, org, course, run):
""" """
......
...@@ -91,7 +91,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestMixedModuleStore(unittest.TestCase):
""" """
AssertEqual replacement for CourseLocator AssertEqual replacement for CourseLocator
""" """
if loc1.version_agnostic() != loc2.version_agnostic(): if loc1.for_branch(None) != loc2.for_branch(None):
self.fail(self._formatMessage(msg, u"{} != {}".format(unicode(loc1), unicode(loc2)))) self.fail(self._formatMessage(msg, u"{} != {}".format(unicode(loc1), unicode(loc2))))
def setUp(self): def setUp(self):
...@@ -125,13 +125,13 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -125,13 +125,13 @@ class TestMixedModuleStore(unittest.TestCase):
# create course # create course
self.course = self.store.create_course(course_key.org, course_key.course, course_key.run, self.user_id) self.course = self.store.create_course(course_key.org, course_key.course, course_key.run, self.user_id)
if isinstance(self.course.id, CourseLocator): if isinstance(self.course.id, CourseLocator):
self.course_locations[self.MONGO_COURSEID] = self.course.location.version_agnostic() self.course_locations[self.MONGO_COURSEID] = self.course.location
else: else:
self.assertEqual(self.course.id, course_key) self.assertEqual(self.course.id, course_key)
# create chapter # create chapter
chapter = self.store.create_child(self.user_id, self.course.location, 'chapter', block_id='Overview') chapter = self.store.create_child(self.user_id, self.course.location, 'chapter', block_id='Overview')
self.writable_chapter_location = chapter.location.version_agnostic() self.writable_chapter_location = chapter.location
def _create_block_hierarchy(self): def _create_block_hierarchy(self):
""" """
...@@ -176,13 +176,13 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -176,13 +176,13 @@ class TestMixedModuleStore(unittest.TestCase):
def create_sub_tree(parent, block_info): def create_sub_tree(parent, block_info):
block = self.store.create_child( block = self.store.create_child(
self.user_id, parent.location.version_agnostic(), self.user_id, parent.location,
block_info.category, block_id=block_info.display_name, block_info.category, block_id=block_info.display_name,
fields={'display_name': block_info.display_name}, fields={'display_name': block_info.display_name},
) )
for tree in block_info.sub_tree: for tree in block_info.sub_tree:
create_sub_tree(block, tree) create_sub_tree(block, tree)
setattr(self, block_info.field_name, block.location.version_agnostic()) setattr(self, block_info.field_name, block.location)
for tree in trees: for tree in trees:
create_sub_tree(self.course, tree) create_sub_tree(self.course, tree)
...@@ -386,7 +386,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -386,7 +386,7 @@ class TestMixedModuleStore(unittest.TestCase):
# Create dummy direct only xblocks # Create dummy direct only xblocks
chapter = self.store.create_item( chapter = self.store.create_item(
self.user_id, self.user_id,
test_course.id.version_agnostic(), test_course.id,
'chapter', 'chapter',
block_id='vertical_container' block_id='vertical_container'
) )
...@@ -407,7 +407,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -407,7 +407,7 @@ class TestMixedModuleStore(unittest.TestCase):
# Create a dummy component to test against # Create a dummy component to test against
xblock = self.store.create_item( xblock = self.store.create_item(
self.user_id, self.user_id,
test_course.id.version_agnostic(), test_course.id,
'vertical', 'vertical',
block_id='test_vertical' block_id='test_vertical'
) )
...@@ -522,7 +522,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -522,7 +522,7 @@ class TestMixedModuleStore(unittest.TestCase):
revision=ModuleStoreEnum.RevisionOption.draft_preferred revision=ModuleStoreEnum.RevisionOption.draft_preferred
) )
self.store.publish(private_vert.location.version_agnostic(), self.user_id) self.store.publish(private_vert.location, self.user_id)
private_leaf.display_name = 'change me' private_leaf.display_name = 'change me'
private_leaf = self.store.update_item(private_leaf, self.user_id) private_leaf = self.store.update_item(private_leaf, self.user_id)
mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID)) mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
...@@ -538,11 +538,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -538,11 +538,7 @@ class TestMixedModuleStore(unittest.TestCase):
mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID)) mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
with check_mongo_calls(mongo_store, max_find, max_send): with check_mongo_calls(mongo_store, max_find, max_send):
courses = self.store.get_courses() courses = self.store.get_courses()
course_ids = [ course_ids = [course.location for course in courses]
course.location.version_agnostic()
if hasattr(course.location, 'version_agnostic') else course.location
for course in courses
]
self.assertEqual(len(courses), 3, "Not 3 courses: {}".format(course_ids)) self.assertEqual(len(courses), 3, "Not 3 courses: {}".format(course_ids))
self.assertIn(self.course_locations[self.MONGO_COURSEID], course_ids) self.assertIn(self.course_locations[self.MONGO_COURSEID], course_ids)
self.assertIn(self.course_locations[self.XML_COURSEID1], course_ids) self.assertIn(self.course_locations[self.XML_COURSEID1], course_ids)
...@@ -622,7 +618,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -622,7 +618,7 @@ class TestMixedModuleStore(unittest.TestCase):
self._create_block_hierarchy() self._create_block_hierarchy()
# publish the course # publish the course
self.course = self.store.publish(self.course.location.version_agnostic(), self.user_id) self.course = self.store.publish(self.course.location, self.user_id)
# make drafts of verticals # make drafts of verticals
self.store.convert_to_draft(self.vertical_x1a, self.user_id) self.store.convert_to_draft(self.vertical_x1a, self.user_id)
...@@ -648,7 +644,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -648,7 +644,7 @@ class TestMixedModuleStore(unittest.TestCase):
]) ])
# publish the course again # publish the course again
self.store.publish(self.course.location.version_agnostic(), self.user_id) self.store.publish(self.course.location, self.user_id)
self.verify_get_parent_locations_results([ self.verify_get_parent_locations_results([
(child_to_move_location, new_parent_location, None), (child_to_move_location, new_parent_location, None),
(child_to_move_location, new_parent_location, ModuleStoreEnum.RevisionOption.draft_preferred), (child_to_move_location, new_parent_location, ModuleStoreEnum.RevisionOption.draft_preferred),
...@@ -888,7 +884,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -888,7 +884,7 @@ class TestMixedModuleStore(unittest.TestCase):
self.initdb(default_ms) self.initdb(default_ms)
block = self.store.create_item( block = self.store.create_item(
self.user_id, self.user_id,
self.course.location.version_agnostic().course_key, self.course.location.course_key,
'problem' 'problem'
) )
self.assertEqual(self.user_id, block.edited_by) self.assertEqual(self.user_id, block.edited_by)
...@@ -899,7 +895,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -899,7 +895,7 @@ class TestMixedModuleStore(unittest.TestCase):
self.initdb(default_ms) self.initdb(default_ms)
block = self.store.create_item( block = self.store.create_item(
self.user_id, self.user_id,
self.course.location.version_agnostic().course_key, self.course.location.course_key,
'problem' 'problem'
) )
self.assertEqual(self.user_id, block.subtree_edited_by) self.assertEqual(self.user_id, block.subtree_edited_by)
...@@ -944,7 +940,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -944,7 +940,7 @@ class TestMixedModuleStore(unittest.TestCase):
self._create_block_hierarchy() self._create_block_hierarchy()
# publish # publish
self.store.publish(self.course.location.version_agnostic(), self.user_id) self.store.publish(self.course.location, self.user_id)
published_xblock = self.store.get_item( published_xblock = self.store.get_item(
self.vertical_x1a, self.vertical_x1a,
revision=ModuleStoreEnum.RevisionOption.published_only revision=ModuleStoreEnum.RevisionOption.published_only
...@@ -980,7 +976,7 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -980,7 +976,7 @@ class TestMixedModuleStore(unittest.TestCase):
# start off as Private # start off as Private
item = self.store.create_child(self.user_id, self.writable_chapter_location, 'problem', 'test_compute_publish_state') item = self.store.create_child(self.user_id, self.writable_chapter_location, 'problem', 'test_compute_publish_state')
item_location = item.location.version_agnostic() item_location = item.location
mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID)) mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
with check_mongo_calls(mongo_store, max_find, max_send): with check_mongo_calls(mongo_store, max_find, max_send):
self.assertEquals(self.store.compute_publish_state(item), PublishState.private) self.assertEquals(self.store.compute_publish_state(item), PublishState.private)
...@@ -1030,13 +1026,13 @@ class TestMixedModuleStore(unittest.TestCase): ...@@ -1030,13 +1026,13 @@ class TestMixedModuleStore(unittest.TestCase):
test_course = self.store.create_course('testx', 'GreekHero', 'test_run', self.user_id) test_course = self.store.create_course('testx', 'GreekHero', 'test_run', self.user_id)
self.assertEqual(self.store.compute_publish_state(test_course), PublishState.public) self.assertEqual(self.store.compute_publish_state(test_course), PublishState.public)
test_course_key = test_course.id.version_agnostic() test_course_key = test_course.id
# test create_item of direct-only category to make sure we are autopublishing # test create_item of direct-only category to make sure we are autopublishing
chapter = self.store.create_item(self.user_id, test_course_key, 'chapter', 'Overview') chapter = self.store.create_item(self.user_id, test_course_key, 'chapter', 'Overview')
self.assertEqual(self.store.compute_publish_state(chapter), PublishState.public) self.assertEqual(self.store.compute_publish_state(chapter), PublishState.public)
chapter_location = chapter.location.version_agnostic() chapter_location = chapter.location
# test create_child of direct-only category to make sure we are autopublishing # test create_child of direct-only category to make sure we are autopublishing
sequential = self.store.create_child(self.user_id, chapter_location, 'sequential', 'Sequence') sequential = self.store.create_child(self.user_id, chapter_location, 'sequential', 'Sequence')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment