Commit d96fa776 by Nimisha Asthagiri

Course API and Transformer

parent a3a48a98
......@@ -905,6 +905,15 @@ class ModuleStoreRead(ModuleStoreAssetBase):
"""
pass
@abstractmethod
def make_course_usage_key(self, course_key):
"""
Return a valid :class:`~opaque_keys.edx.keys.UsageKey` for this modulestore
that matches the supplied course_key.
"""
pass
@abstractmethod
def get_courses(self, **kwargs):
'''
......
......@@ -313,6 +313,15 @@ class MixedModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase):
# Otherwise, return the key created by the default store
return self.default_modulestore.make_course_key(org, course, run)
def make_course_usage_key(self, course_key):
"""
Return a valid :class:`~opaque_keys.edx.keys.UsageKey` for the modulestore
that matches the supplied course_key.
"""
assert isinstance(course_key, CourseKey)
store = self._get_modulestore_for_courselike(course_key)
return store.make_course_usage_key(course_key)
@strip_key
def get_course(self, course_key, depth=0, **kwargs):
"""
......
......@@ -1033,6 +1033,13 @@ class MongoModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase, Mongo
"""
return CourseLocator(org, course, run, deprecated=True)
def make_course_usage_key(self, course_key):
"""
Return a valid :class:`~opaque_keys.edx.keys.UsageKey` for this modulestore
that matches the supplied course_key.
"""
return BlockUsageLocator(course_key, 'course', course_key.run)
def get_course(self, course_key, depth=0, **kwargs):
"""
Get the course with the given courseid (org/course/run)
......
......@@ -948,6 +948,13 @@ class SplitMongoModuleStore(SplitBulkWriteMixin, ModuleStoreWriteBase):
"""
return CourseLocator(org, course, run)
def make_course_usage_key(self, course_key):
"""
Return a valid :class:`~opaque_keys.edx.keys.UsageKey` for this modulestore
that matches the supplied course_key.
"""
return BlockUsageLocator(course_key, 'course', 'course')
def _get_structure(self, structure_id, depth, head_validation=True, **kwargs):
"""
Gets Course or Library by locator
......
......@@ -27,7 +27,7 @@ from xmodule.modulestore.xml_exporter import DEFAULT_CONTENT_FIELDS
from xmodule.modulestore import ModuleStoreEnum, ModuleStoreReadBase, LIBRARY_ROOT, COURSE_ROOT
from xmodule.tabs import CourseTabList
from opaque_keys.edx.locations import SlashSeparatedCourseKey, Location
from opaque_keys.edx.locator import CourseLocator, LibraryLocator
from opaque_keys.edx.locator import CourseLocator, LibraryLocator, BlockUsageLocator
from xblock.field_data import DictFieldData
from xblock.runtime import DictKeyValueStore
......@@ -821,6 +821,13 @@ class XMLModuleStore(ModuleStoreReadBase):
"""
return CourseLocator(org, course, run, deprecated=True)
def make_course_usage_key(self, course_key):
"""
Return a valid :class:`~opaque_keys.edx.keys.UsageKey` for this modulestore
that matches the supplied course_key.
"""
return BlockUsageLocator(course_key, 'course', course_key.run)
def get_courses(self, **kwargs):
"""
Returns a list of course descriptors. If there were errors on loading,
......
......@@ -765,7 +765,7 @@ class VideoDescriptor(VideoFields, VideoTranscriptsMixin, VideoStudioViewHandler
"""
return edxval_api.get_video_info_for_course_and_profiles(unicode(course_id), video_profile_names)
def student_view_json(self, context):
def student_view_data(self, context):
"""
Returns a JSON representation of the student_view of this XModule.
The contract of the JSON content is between the caller and the particular XModule.
......@@ -780,7 +780,7 @@ class VideoDescriptor(VideoFields, VideoTranscriptsMixin, VideoStudioViewHandler
# Check in VAL data first if edx_video_id exists
if self.edx_video_id:
video_profile_names = context.get("profiles", [])
video_profile_names = context.get("profiles", ["mobile_low"])
# get and cache bulk VAL data for course
val_course_data = self.get_cached_val_data_for_course(video_profile_names, self.location.course_key)
......
"""
"""
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError
from django.forms import Form, CharField, Field, MultipleHiddenInput
from django.http import Http404
from rest_framework.exceptions import PermissionDenied
from courseware.access import _has_access_to_course
from xmodule.modulestore.django import modulestore
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import UsageKey
from transformers.student_view import StudentViewTransformer
from transformers.block_counts import BlockCountsTransformer
class ListField(Field):
"""
Field for a list of strings
"""
widget = MultipleHiddenInput
class BlockListGetForm(Form):
"""
A form to validate query parameters in the block list retrieval endpoint
"""
user = CharField(required=True) # TODO return all blocks if user is not specified by requesting staff user
usage_key = CharField(required=True)
requested_fields = ListField(required=False)
student_view_data = ListField(required=False)
block_counts = ListField(required=False)
depth = CharField(required=False)
def clean_requested_fields(self):
# add default requested_fields
return set(self.cleaned_data['requested_fields'] or set()) | {'type', 'display_name'}
def clean_depth(self):
value = self.cleaned_data['depth']
if not value:
return 0
elif value == "all":
return None
try:
return int(value)
except ValueError:
raise ValidationError("'{}' is not a valid depth value".format(value))
def clean_usage_key(self):
usage_key = self.cleaned_data['usage_key']
try:
usage_key = UsageKey.from_string(usage_key)
usage_key = usage_key.replace(course_key=modulestore().fill_in_run(usage_key.course_key))
except InvalidKeyError:
raise ValidationError("'{}' is not a valid usage key".format(unicode(usage_key)))
return usage_key
def clean(self):
cleaned_data = super(BlockListGetForm, self).clean()
# add additional requested_fields that are specified as separate parameters, if they were requested
for additional_field in [StudentViewTransformer.STUDENT_VIEW_DATA, BlockCountsTransformer.BLOCK_COUNTS]:
if cleaned_data.get(additional_field):
cleaned_data['requested_fields'].add(additional_field)
# validate and set user
usage_key = self.cleaned_data.get('usage_key')
if not usage_key:
return
requested_username = cleaned_data.get('user', '')
requesting_user = self.initial['request'].user
if requesting_user.username.lower() == requested_username.lower():
cleaned_data['user'] = requesting_user
else:
# the requesting user is trying to access another user's view
# verify requesting user is staff and update requested user's object
if not _has_access_to_course(requesting_user, 'staff', usage_key.course_key):
raise PermissionDenied(
"'{requesting_username}' does not have permission to access view for '{requested_username}'."
.format(requesting_username=requesting_user.username, requested_username=requested_username)
)
# get requested user object
try:
cleaned_data['user'] = User.objects.get(username=requested_username)
except (User.DoesNotExist):
raise Http404("'{username}' does not exist.".format(username=requested_username))
return cleaned_data
"""
Serializers for all Course Blocks related return objects.
"""
from rest_framework import serializers
from rest_framework.reverse import reverse
from transformers import SUPPORTED_FIELDS
class BlockSerializer(serializers.Serializer):
"""
TODO
"""
def _get_field(self, block_key, transformer, field_name):
if transformer:
return self.context['block_structure'].get_transformer_block_data(block_key, transformer, field_name)
else:
return self.context['block_structure'].get_xblock_field(block_key, field_name)
def to_native(self, block_key):
data = {
'id': unicode(block_key),
'lms_web_url': reverse(
'jump_to',
kwargs={'course_id': unicode(block_key.course_key), 'location': unicode(block_key)},
request=self.context['request'],
),
'student_view_url': reverse(
'courseware.views.render_xblock',
kwargs={'usage_key_string': unicode(block_key)},
request=self.context['request'],
),
}
for supported_field in SUPPORTED_FIELDS:
if supported_field.requested_field_name in self.context['requested_fields']:
data[supported_field.requested_field_name] = self._get_field(
block_key,
supported_field.transformer,
supported_field.block_field_name,
)
return data
"""
TODO
"""
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
class TestCourseAPI(ModuleStoreTestCase):
pass
"""
Course API Block Transformers
"""
from student_view import StudentViewTransformer
from block_counts import BlockCountsTransformer
class SupportedFieldType(object):
def __init__(self, block_field_name, transformer=None, requested_field_name=None):
self.transformer = transformer
self.block_field_name = block_field_name
self.requested_field_name = requested_field_name or block_field_name
SUPPORTED_FIELDS = (
SupportedFieldType('category', None, 'type'),
SupportedFieldType('display_name'),
SupportedFieldType('graded'),
SupportedFieldType('format'),
SupportedFieldType(StudentViewTransformer.STUDENT_VIEW_DATA, StudentViewTransformer),
SupportedFieldType(StudentViewTransformer.STUDENT_VIEW_MULTI_DEVICE, StudentViewTransformer),
# set the block_field_name to None so the entire data for the transformer is serialized
SupportedFieldType(None, BlockCountsTransformer, BlockCountsTransformer.BLOCK_COUNTS),
)
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
class BlockCountsTransformer(BlockStructureTransformer):
"""
...
"""
VERSION = 1
BLOCK_COUNTS = 'block_counts'
def __init__(self, block_types_to_count):
self.block_types_to_count = block_types_to_count
@classmethod
def collect(cls, block_structure):
"""
Collects any information that's necessary to execute this transformer's
transform method.
"""
# collect basic xblock fields
block_structure.request_xblock_fields('category')
def transform(self, user_info, block_structure):
"""
Mutates block_structure based on the given user_info.
"""
if not self.block_types_to_count:
return
for block_key in block_structure.post_order_traversal():
for block_type in self.block_types_to_count:
descendants_type_count = sum([
block_structure.get_transformer_block_data(child_key, self, block_type, 0)
for child_key in block_structure.get_children(block_key)
])
block_structure.set_transformer_block_data(
block_key,
self,
block_type,
(
descendants_type_count +
(1 if (block_structure.get_xblock_field(block_key, 'category') == block_type) else 0)
)
)
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from course_api.blocks.transformers.block_counts import BlockCountsTransformer
from course_api.blocks.transformers.student_view import StudentViewTransformer
class BlocksAPITransformer(BlockStructureTransformer):
"""
...
"""
VERSION = 1
STUDENT_VIEW_DATA = 'student_view_data'
STUDENT_VIEW_MULTI_DEVICE = 'student_view_multi_device'
def __init__(self, block_counts, requested_student_view_data):
self.block_counts = block_counts
self.requested_student_view_data = requested_student_view_data
@classmethod
def collect(cls, block_structure):
"""
Collects any information that's necessary to execute this transformer's
transform method.
"""
# collect basic xblock fields
block_structure.request_xblock_fields('graded', 'format', 'display_name', 'category')
# collect data from containing transformers
StudentViewTransformer.collect(block_structure)
BlockCountsTransformer.collect(block_structure)
# TODO support olx_data by calling export_to_xml(?)
def transform(self, user_info, block_structure):
"""
Mutates block_structure based on the given user_info.
"""
StudentViewTransformer(self.requested_student_view_data).transform(user_info, block_structure)
BlockCountsTransformer(self.block_counts).transform(user_info, block_structure)
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
class StudentViewTransformer(BlockStructureTransformer):
"""
...
"""
VERSION = 1
STUDENT_VIEW_DATA = 'student_view_data'
STUDENT_VIEW_MULTI_DEVICE = 'student_view_multi_device'
def __init__(self, requested_student_view_data):
self.requested_student_view_data = requested_student_view_data
@classmethod
def collect(cls, block_structure):
"""
Collect student_view_multi_device and student_view_data values for each block
"""
# TODO
# File "/edx/app/edxapp/edx-platform/common/lib/xmodule/xmodule/x_module.py", line 1125, in _xmodule
# raise UndefinedContext()
# for block_key in block_structure.topological_traversal():
# block = block_structure.get_xblock(block_key)
# block_structure.set_transformer_block_data(
# block_key,
# cls,
# cls.STUDENT_VIEW_MULTI_DEVICE,
# block.has_support(getattr(block, 'student_view', None), 'multi_device'),
# )
# if getattr(block, 'student_view_data', None):
# block_structure.set_transformer_block_data(
# block_key,
# cls,
# cls.STUDENT_VIEW_DATA,
# block.student_view_data(),
# )
def transform(self, user_info, block_structure):
"""
Mutates block_structure based on the given user_info.
"""
for block_key in block_structure.post_order_traversal():
if block_structure.get_xblock_field(block_key, 'type') not in self.requested_student_view_data:
block_structure.remove_transformer_block_data(block_key, self, self.STUDENT_VIEW_DATA)
"""
Course Block API URLs
"""
from django.conf import settings
from django.conf.urls import patterns, url
from .views import CourseBlocks
urlpatterns = patterns(
'',
url(
r"^{}".format(settings.USAGE_KEY_PATTERN),
CourseBlocks.as_view(),
name="course_blocks"
),
)
from django.core.exceptions import ValidationError
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from lms.djangoapps.course_blocks.api import get_course_blocks, LMS_COURSE_TRANSFORMERS
from openedx.core.lib.api.view_utils import view_auth_classes, DeveloperErrorViewMixin
from transformers.blocks_api import BlocksAPITransformer
from transformers.block_counts import BlockCountsTransformer
from transformers.student_view import StudentViewTransformer
from .forms import BlockListGetForm
from .serializers import BlockSerializer
# TODO
# user not specified (-> staff)
# children field
# navigation to return descendants
# support hide_from_toc
@view_auth_classes()
class CourseBlocks(DeveloperErrorViewMixin, ListAPIView):
"""
**Use Case**
Returns the blocks of the course according to the requesting user's access level.
**Example requests**:
GET /api/courses/v1/blocks/<root_block_usage_id>/?depth=all
GET /api/courses/v1/blocks/<usage_id>/?
user=anjali,
&fields=graded,format,multi_device,
&block_counts=video,
&student_view_data=video,
&student_view_data.video=mobile_low
**Parameters**:
* student_view_data: (list) Indicates for which block types to return student_view_data.
Example: student_view_data=video
* block_counts: (list) Indicates for which block types to return the aggregate count of the blocks.
Example: block_counts=video,problem
* fields: (list) Indicates which additional fields to return for each block.
Default is children,graded,format,student_view_multi_device
Example: fields=graded,format,student_view_multi_device
* depth (integer or all) Indicates how deep to traverse into the blocks hierarchy.
A value of all means the entire hierarchy.
Default is 0
Example: depth=all
**Response Values**
The following fields are returned with a successful response.
* root: The ID of the root node of the course blocks.
* blocks: A dictionary that maps block usage IDs to a collection of information about each block.
Each block contains the following fields.
* id: (string) The usage ID of the block.
* type: (string) The type of block. Possible values include course, chapter, sequential, vertical, html,
problem, video, and discussion. The type can also be the name of a custom type of block used for the course.
* display_name: (string) The display name of the block.
* children: (list) If the block has child blocks, a list of IDs of the child blocks.
Returned only if the "children" input parameter is True.
* block_counts: (dict) For each block type specified in the block_counts parameter to the endpoint, the
aggregate number of blocks of that type for this block and all of its descendants.
Returned only if the "block_counts" input parameter contains this block's type.
* graded (boolean) Whether or not the block or any of its descendants is graded.
Returned only if "graded" is included in the "fields" parameter.
* format: (string) The assignment type of the block.
Possible values can be "Homework", "Lab", "Midterm Exam", and "Final Exam".
Returned only if "format" is included in the "fields" parameter.
* student_view_data: (dict) The JSON data for this block.
Returned only if the "student_view_data" input parameter contains this block's type.
* student_view_url: (string) The URL to retrieve the HTML rendering of this block's student view.
The HTML could include CSS and Javascript code. This field can be used in combination with the
student_view_multi_device field to decide whether to display this content to the user.
This URL can be used as a fallback if the student_view_data for this block type is not supported by
the client or the block.
* student_view_multi_device: (boolean) Whether or not the block's rendering obtained via block_url has support
for multiple devices.
Returned only if "student_view_multi_device" is included in the "fields" parameter.
* lms_web_url: (string) The URL to the navigational container of the xBlock on the web LMS.
This URL can be used as a further fallback if the student_view_url and the student_view_data fields
are not supported.
"""
def list(self, request, usage_key_string):
"""
REST API endpoint for listing all the blocks and/or navigation information in the course,
while regarding user access and roles.
Arguments:
request - Django request object
course - course module object
"""
requested_params = request.GET.copy()
requested_params.update({'usage_key': usage_key_string})
params = BlockListGetForm(requested_params, initial={'request': request})
if not params.is_valid():
raise ValidationError(params.errors)
blocks_api_transformer = BlocksAPITransformer(
params.cleaned_data.get(BlockCountsTransformer.BLOCK_COUNTS, []),
params.cleaned_data.get(StudentViewTransformer.STUDENT_VIEW_DATA, []),
)
blocks = get_course_blocks(
params.cleaned_data['user'],
params.cleaned_data['usage_key'],
transformers=LMS_COURSE_TRANSFORMERS | {blocks_api_transformer},
)
return Response(
BlockSerializer(
blocks,
context={
'request': request,
'block_structure': blocks,
'requested_fields': params.cleaned_data['requested_fields'],
},
many=True,
).data
)
"""
Course API URLs
"""
from django.conf import settings
from django.conf.urls import patterns, url, include
from .views import CourseView
urlpatterns = patterns(
'',
url(r'^v1/course/{}'.format(settings.COURSE_KEY_PATTERN), CourseView.as_view(), name="course_detail"),
url(r'^v1/blocks/', include('course_api.blocks.urls'))
)
"""
Course API Views
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.reverse import reverse
from opaque_keys.edx.keys import CourseKey
from openedx.core.lib.api.view_utils import view_auth_classes
from xmodule.modulestore.django import modulestore
@view_auth_classes()
class CourseView(APIView):
"""
View class for the Course API
"""
def get(self, request, course_key_string):
course_key = CourseKey.from_string(course_key_string)
course_usage_key = modulestore().make_course_usage_key(course_key)
return Response({
'blocks_url': reverse(
'course_blocks',
kwargs={'usage_key_string': unicode(course_usage_key)},
request=request,
)
})
......@@ -5,7 +5,7 @@ from django.dispatch.dispatcher import receiver
from xmodule.modulestore.django import SignalHandler
from api import clear_course_from_cache
from .api import clear_course_from_cache
@receiver(SignalHandler.course_published)
......
......@@ -2,7 +2,6 @@
...
"""
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from xmodule.course_metadata_utils import DEFAULT_START_DATE
from courseware.access_utils import check_start_date
......
......@@ -136,9 +136,7 @@ class BlockParentsMapTestCase(ModuleStoreTestCase):
):
def check_results(user, expected_accessible_blocks, blocks_with_differing_access):
self.client.login(username=user.username, password=self.password)
block_structure = get_course_blocks(
user, self.course.id, self.course.location, transformers=transformers
)
block_structure = get_course_blocks(user, self.course.location, transformers=transformers)
for i, xblock_key in enumerate(self.xblock_keys):
block_structure_result = block_structure.has_block(xblock_key)
has_access_result = bool(has_access(user, 'load', self.get_block(i)))
......
......@@ -7,12 +7,10 @@ from datetime import timedelta
from django.utils.timezone import now
from mock import patch
from courseware.access import has_access
from xmodule.course_metadata_utils import DEFAULT_START_DATE
from ..start_date import StartDateTransformer
from .test_helpers import BlockParentsMapTestCase
from xmodule.modulestore.django import modulestore
@ddt.ddt
......
......@@ -120,7 +120,6 @@ class UserPartitionTransformerTestCase(CourseStructureTestCase):
raw_block_structure = get_course_blocks(
self.user,
self.course.id,
self.course.location,
transformers={}
)
......@@ -129,7 +128,6 @@ class UserPartitionTransformerTestCase(CourseStructureTestCase):
clear_course_from_cache(self.course.id)
trans_block_structure = get_course_blocks(
self.user,
self.course.id,
self.course.location,
transformers={self.transformation}
)
......
......@@ -688,7 +688,7 @@ def _adjust_start_date_for_beta_testers(user, descriptor, course_key): # pylint
the user is looking at. Once we have proper usages and definitions per the XBlock
design, this should use the course the usage is in.
"""
return adjust_start_date(descriptor.days_early_for_beta, descriptor.start, course_key)
return adjust_start_date(user, descriptor.days_early_for_beta, descriptor.start, course_key)
def _has_instructor_access_to_location(user, location, course_key=None):
......
......@@ -869,7 +869,7 @@ class TestVideoDescriptorInitialization(BaseTestXmodule):
@ddt.ddt
class TestVideoDescriptorStudentViewJson(TestCase):
"""
Tests for the student_view_json method on VideoDescriptor.
Tests for the student_view_data method on VideoDescriptor.
"""
TEST_DURATION = 111.0
TEST_PROFILE = "mobile"
......@@ -914,15 +914,15 @@ class TestVideoDescriptorStudentViewJson(TestCase):
def get_result(self, allow_cache_miss=True):
"""
Returns the result from calling the video's student_view_json method.
Returns the result from calling the video's student_view_data method.
Arguments:
allow_cache_miss is passed in the context to the student_view_json method.
allow_cache_miss is passed in the context to the student_view_data method.
"""
context = {
"profiles": [self.TEST_PROFILE],
"allow_cache_miss": "True" if allow_cache_miss else "False"
}
return self.video.student_view_json(context)
return self.video.student_view_data(context)
def verify_result_with_fallback_url(self, result):
"""
......
......@@ -75,6 +75,9 @@ urlpatterns = (
# Course content API
url(r'^api/course_structure/', include('course_structure_api.urls', namespace='course_structure_api')),
# Course API
url(r'^api/course/', include('course_api.urls')),
# User API endpoints
url(r'^api/user/', include('openedx.core.djangoapps.user_api.urls')),
......
......@@ -24,7 +24,6 @@ from openedx.core.lib.api.authentication import (
OAuth2AuthenticationAllowInactiveUser,
)
from openedx.core.lib.api.permissions import IsUserInUrl
from util.milestones_helpers import any_unfulfilled_milestones
class DeveloperErrorViewMixin(object):
......@@ -66,7 +65,7 @@ class DeveloperErrorViewMixin(object):
if isinstance(exc, APIException):
return self.make_error_response(exc.status_code, exc.detail)
elif isinstance(exc, Http404):
return self.make_error_response(404, "Not found.")
return self.make_error_response(404, exc.message or "Not found.")
elif isinstance(exc, ValidationError):
return self.make_validation_error_response(exc)
else:
......@@ -113,6 +112,19 @@ def view_course_access(depth=0, access_action='load', check_for_milestones=False
return _decorator
class IsAuthenticatedAndNotAnonymous(IsAuthenticated):
"""
Allows access only to authenticated and non-anonymous users.
"""
def has_permission(self, request, view):
return (
# verify the user is authenticated and
super(IsAuthenticatedAndNotAnonymous, self).has_permission(request, view) and
# not anonymous
not request.user.is_anonymous()
)
def view_auth_classes(is_user=False):
"""
Function and class decorator that abstracts the authentication and permission checks for api views.
......@@ -126,7 +138,7 @@ def view_auth_classes(is_user=False):
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser
)
func_or_class.permission_classes = (IsAuthenticated,)
func_or_class.permission_classes = (IsAuthenticatedAndNotAnonymous,)
if is_user:
func_or_class.permission_classes += (IsUserInUrl,)
return func_or_class
......
......@@ -6,8 +6,11 @@ from transformer import BlockStructureTransformers
def get_blocks(cache, modulestore, user_info, root_block_key, transformers):
if not BlockStructureTransformers.are_all_registered(transformers):
raise Exception("One or more requested transformers are not registered.")
unregistered_transformers = BlockStructureTransformers.find_unregistered(transformers)
if unregistered_transformers:
raise Exception(
"The following requested transformers are not registered: {}".format(unregistered_transformers)
)
# Load the cached block structure.
root_block_structure = BlockStructureFactory.create_from_cache(root_block_key, cache)
......
......@@ -27,6 +27,9 @@ class BlockStructure(object):
self._block_relations = defaultdict(self.BlockRelations)
self._add_block(self._block_relations, root_block_key)
def __iter__(self):
return self.topological_traversal()
def add_relation(self, parent_key, child_key):
self._add_relation(self._block_relations, parent_key, child_key)
......@@ -51,6 +54,14 @@ class BlockStructure(object):
predicate=predicate,
)
def post_order_traversal(self, get_result=None, predicate=None):
return traverse_post_order(
start_node=self.root_block_key,
get_children=self.get_children,
get_result=get_result,
predicate=predicate,
)
def prune(self):
# create a new block relations map with only those blocks that are still linked
pruned_block_relations = defaultdict(self.BlockRelations)
......@@ -72,12 +83,7 @@ class BlockStructure(object):
if child in pruned_block_relations:
self._add_relation(pruned_block_relations, block_key, child)
list(traverse_post_order(
start_node=self.root_block_key,
get_children=self.get_children,
get_result=do_for_each_block
))
list(self.post_order_traversal(get_result=do_for_each_block))
self._block_relations = pruned_block_relations
@classmethod
......@@ -118,14 +124,25 @@ class BlockStructureBlockData(BlockStructure):
def get_transformer_data(self, transformer, key, default=None):
return self._transformer_data.get(transformer.name(), {}).get(key, default)
def set_transformer_data(self, transformer, key, value):
self._transformer_data[transformer.name()][key] = value
def get_transformer_data_version(self, transformer):
return self.get_transformer_data(transformer, TRANSFORMER_VERSION_KEY, 0)
def get_transformer_block_data(self, usage_key, transformer, key, default=None):
def get_transformer_block_data(self, usage_key, transformer, key=None, default=None):
block_data = self._block_data_map.get(usage_key)
return block_data._transformer_data.get(
transformer.name(), {}
).get(key, default) if block_data else default
if not block_data:
return default
else:
transformer_data = block_data._transformer_data.get(transformer.name(), {})
return transformer_data.get(key, default) if key else transformer_data
def set_transformer_block_data(self, usage_key, transformer, key, value):
self._block_data_map[usage_key]._transformer_data[transformer.name()][key] = value
def remove_transformer_block_data(self, usage_key, transformer, key):
self._block_data_map[usage_key]._transformer_data.get(transformer.name(), {}).pop(key, None)
def remove_block(self, usage_key):
# Remove block from its children.
......@@ -137,10 +154,8 @@ class BlockStructureBlockData(BlockStructure):
self._block_relations[parent_key].children.remove(usage_key)
# Remove block.
if usage_key in self._block_relations:
del self._block_relations[usage_key]
if usage_key in self._block_data_map:
del self._block_data_map[usage_key]
self._block_relations.pop(usage_key, None)
self._block_data_map.pop(usage_key, None)
def remove_block_if(self, removal_condition):
def predicate(block_key):
......@@ -187,12 +202,6 @@ class BlockStructureCollectedData(BlockStructureBlockData):
raise Exception('VERSION attribute is not set on transformer {0}.', transformer.name())
self.set_transformer_data(transformer, TRANSFORMER_VERSION_KEY, transformer.VERSION)
def set_transformer_data(self, transformer, key, value):
self._transformer_data[transformer.name()][key] = value
def set_transformer_block_data(self, usage_key, transformer, key, value):
self._block_data_map[usage_key]._transformer_data[transformer.name()][key] = value
class BlockStructureFactory(object):
@classmethod
......
......@@ -23,18 +23,26 @@ class BlockStructureTransformersTestCase(TestCase):
pass
@patch('openedx.core.lib.block_cache.transformer.BlockStructureTransformers.get_available_plugins')
def test_are_all_registered(self, mock_available_transforms):
def test_find_unregistered(self, mock_available_transforms):
mock_available_transforms.return_value = {
transformer.name(): transformer
for transformer in [self.TestTransformer1, self.TestTransformer2]
}
for transformers, expected_are_all_registered in [
([], True),
([self.TestTransformer1()], True),
([self.TestTransformer1(), self.TestTransformer2()], True),
([self.UnregisteredTestTransformer3()], False),
([self.TestTransformer1(), self.UnregisteredTestTransformer3()], False),
for transformers, expected_find_unregistered in [
([], []),
([self.TestTransformer1()], []),
([self.TestTransformer1(), self.TestTransformer2()], []),
(
[self.UnregisteredTestTransformer3()],
[self.UnregisteredTestTransformer3.name()]
),
(
[self.TestTransformer1(), self.UnregisteredTestTransformer3()],
[self.UnregisteredTestTransformer3.name()]
),
]:
self.assertEquals(BlockStructureTransformers.are_all_registered(transformers), expected_are_all_registered)
self.assertSetEqual(
BlockStructureTransformers.find_unregistered(transformers), set(expected_find_unregistered)
)
......@@ -45,9 +45,7 @@ class BlockStructureTransformers(PluginManager):
return set(cls.get_available_plugins().itervalues())
@classmethod
def are_all_registered(cls, transformers):
registered_transformers = cls.get_registered_transformers()
return all(
any(transformer.name() == reg_trans.name() for reg_trans in registered_transformers)
for transformer in transformers
)
def find_unregistered(cls, transformers):
registered_transformer_names = set(reg_trans.name() for reg_trans in cls.get_registered_transformers())
requested_transformer_names = set(transformer.name() for transformer in transformers)
return requested_transformer_names - registered_transformer_names
......@@ -54,6 +54,7 @@ setup(
"user_partitions = lms.djangoapps.course_blocks.transformers.user_partitions:UserPartitionTransformer",
"split_test = lms.djangoapps.course_blocks.transformers.split_test:SplitTestTransformer",
"library_content = lms.djangoapps.course_blocks.transformers.library_content:ContentLibraryTransformer",
"blocks_api = lms.djangoapps.course_api.blocks.transformers.blocks_api:BlocksAPITransformer",
],
}
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment