Commit f0322b96 by Greg Price

Merge branch 'release'

parents 9d9dccdb c65c19d2
...@@ -46,7 +46,7 @@ class StatusDisplayStrings(object): ...@@ -46,7 +46,7 @@ class StatusDisplayStrings(object):
# Translators: This is the status for a video that the servers are currently processing # Translators: This is the status for a video that the servers are currently processing
_IN_PROGRESS = ugettext_noop("In Progress") _IN_PROGRESS = ugettext_noop("In Progress")
# Translators: This is the status for a video that the servers have successfully processed # Translators: This is the status for a video that the servers have successfully processed
_COMPLETE = ugettext_noop("Complete") _COMPLETE = ugettext_noop("Ready")
# Translators: This is the status for a video that the servers have failed to process # Translators: This is the status for a video that the servers have failed to process
_FAILED = ugettext_noop("Failed") _FAILED = ugettext_noop("Failed")
# Translators: This is the status for a video for which an invalid # Translators: This is the status for a video for which an invalid
......
...@@ -26,7 +26,17 @@ from uuid import uuid4 ...@@ -26,7 +26,17 @@ from uuid import uuid4
# import settings from LMS for consistent behavior with CMS # import settings from LMS for consistent behavior with CMS
# pylint: disable=unused-import # pylint: disable=unused-import
from lms.envs.test import (WIKI_ENABLED, PLATFORM_NAME, SITE_NAME, DEFAULT_FILE_STORAGE, MEDIA_ROOT, MEDIA_URL) from lms.envs.test import (
WIKI_ENABLED,
PLATFORM_NAME,
SITE_NAME,
DEFAULT_FILE_STORAGE,
MEDIA_ROOT,
MEDIA_URL,
# This is practically unused but needed by the oauth2_provider package, which
# some tests in common/ rely on.
OAUTH_OIDC_ISSUER,
)
# mongo connection settings # mongo connection settings
MONGO_PORT_NUM = int(os.environ.get('EDXAPP_TEST_MONGO_PORT', '27017')) MONGO_PORT_NUM = int(os.environ.get('EDXAPP_TEST_MONGO_PORT', '27017'))
......
...@@ -64,9 +64,10 @@ ...@@ -64,9 +64,10 @@
<h3 class="title-3">${_("Monitoring files as they upload")}</h3> <h3 class="title-3">${_("Monitoring files as they upload")}</h3>
<p>${_("Each video file that you upload needs to reach the video processing servers successfully before additional work can begin. You can monitor the progress of files as they upload, and try again if the upload fails.")}</p> <p>${_("Each video file that you upload needs to reach the video processing servers successfully before additional work can begin. You can monitor the progress of files as they upload, and try again if the upload fails.")}</p>
<h3 class="title-3">${_("Managing uploaded files")}</h3> <h3 class="title-3">${_("Managing uploaded files")}</h3>
<p>${_("After a file uploads successfully, automated processing begins. After automated processing begins for a file it is listed under Previous Uploads as {em_start}In Progress{em_end}. When the status is {em_start}Complete{em_end}, edX assigns a unique video ID to the video file and you can add it to your course. If something goes wrong, the {em_start}Failed{em_end} status message appears. Check for problems in the file and upload a replacement.").format(em_start='<strong>', em_end="</strong>")}</p> <p>${_("After a file uploads successfully, automated processing begins. After automated processing begins for a file it is listed under Previous Uploads as {em_start}In Progress{em_end}. You can add the video to your course as soon as it has a unique video ID and the status is {em_start}Ready{em_end}. Allow 24 hours for file processing at the external video hosting sites to complete.").format(em_start='<strong>', em_end="</strong>")}</p>
<p>${_("If something goes wrong, the {em_start}Failed{em_end} status message appears. Check for problems in your original file and upload a replacement.").format(em_start='<strong>', em_end="</strong>")}</p>
<h3 class="title-3">${_("How do I get the videos into my course?")}</h3> <h3 class="title-3">${_("How do I get the videos into my course?")}</h3>
<p>${_("After processing is complete for the video file, you copy its unique video ID. On the Course Outline page, you create or locate a video component to play this video. Edit the video component to paste the ID into the Advanced {em_start}EdX Video ID{em_end} field.").format(em_start='<strong>', em_end="</strong>")}</p> <p>${_("As soon as the status for a file is {em_start}Ready{em_end}, you can add that video to a component in your course. Copy the unique video ID. In another browser window, on the Course Outline page, create or locate a video component to play this video. Edit the video component to paste the ID into the Advanced {em_start}EdX Video ID{em_end} field. The video can play in the LMS as soon as its status is {em_start}Ready{em_end}, although processing may not be complete for all encodings and all video hosting sites.").format(em_start='<strong>', em_end="</strong>")}</p>
</div> </div>
</aside> </aside>
</section> </section>
......
...@@ -401,6 +401,8 @@ class CountryAccessRule(models.Model): ...@@ -401,6 +401,8 @@ class CountryAccessRule(models.Model):
CACHE_KEY = u"embargo.allowed_countries.{course_key}" CACHE_KEY = u"embargo.allowed_countries.{course_key}"
ALL_COUNTRIES = set(code[0] for code in list(countries))
@classmethod @classmethod
def check_country_access(cls, course_id, country): def check_country_access(cls, course_id, country):
""" """
...@@ -415,6 +417,14 @@ class CountryAccessRule(models.Model): ...@@ -415,6 +417,14 @@ class CountryAccessRule(models.Model):
True if country found in allowed country True if country found in allowed country
otherwise check given country exists in list otherwise check given country exists in list
""" """
# If the country code is not in the list of all countries,
# we don't want to automatically exclude the user.
# This can happen, for example, when GeoIP falls back
# to using a continent code because it cannot determine
# the specific country.
if country not in cls.ALL_COUNTRIES:
return True
cache_key = cls.CACHE_KEY.format(course_key=course_id) cache_key = cls.CACHE_KEY.format(course_key=course_id)
allowed_countries = cache.get(cache_key) allowed_countries = cache.get(cache_key)
if allowed_countries is None: if allowed_countries is None:
...@@ -454,7 +464,7 @@ class CountryAccessRule(models.Model): ...@@ -454,7 +464,7 @@ class CountryAccessRule(models.Model):
# If there are no whitelist countries, default to all countries # If there are no whitelist countries, default to all countries
if not whitelist_countries: if not whitelist_countries:
whitelist_countries = set(code[0] for code in list(countries)) whitelist_countries = cls.ALL_COUNTRIES
# Consolidate the rules into a single list of countries # Consolidate the rules into a single list of countries
# that have access to the course. # that have access to the course.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Tests for EmbargoMiddleware Tests for EmbargoMiddleware
""" """
from contextlib import contextmanager
import mock import mock
import unittest import unittest
import pygeoip import pygeoip
...@@ -85,9 +86,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): ...@@ -85,9 +86,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
self.user.profile.save() self.user.profile.save()
# Appear to make a request from an IP in a particular country # Appear to make a request from an IP in a particular country
with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip: with self._mock_geoip(ip_country):
mock_ip.return_value = ip_country
# Call the API. Note that the IP address we pass in doesn't # Call the API. Note that the IP address we pass in doesn't
# matter, since we're injecting a mock for geo-location # matter, since we're injecting a mock for geo-location
result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0')
...@@ -113,9 +112,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): ...@@ -113,9 +112,7 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
country=Country.objects.get(country='US') country=Country.objects.get(country='US')
) )
with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip: with self._mock_geoip('US'):
mock_ip.return_value = 'US'
# The user is set to None, because the user has not been authenticated. # The user is set to None, because the user has not been authenticated.
result = embargo_api.check_course_access(self.course.id, ip_address='0.0.0.0') result = embargo_api.check_course_access(self.course.id, ip_address='0.0.0.0')
self.assertFalse(result) self.assertFalse(result)
...@@ -137,6 +134,14 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): ...@@ -137,6 +134,14 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='FE80::0202:B3FF:FE1E:8329') result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='FE80::0202:B3FF:FE1E:8329')
self.assertTrue(result) self.assertTrue(result)
def test_country_access_fallback_to_continent_code(self):
# Simulate PyGeoIP falling back to a continent code
# instead of a country code. In this case, we should
# allow the user access.
with self._mock_geoip('EU'):
result = embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0')
self.assertTrue(result)
@mock.patch.dict(settings.FEATURES, {'EMBARGO': True}) @mock.patch.dict(settings.FEATURES, {'EMBARGO': True})
def test_profile_country_db_null(self): def test_profile_country_db_null(self):
# Django country fields treat NULL values inconsistently. # Django country fields treat NULL values inconsistently.
...@@ -156,15 +161,16 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): ...@@ -156,15 +161,16 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
self.assertTrue(result) self.assertTrue(result)
def test_caching(self): def test_caching(self):
# Test the scenario that will go through every check with self._mock_geoip('US'):
# (restricted course, but pass all the checks) # Test the scenario that will go through every check
# This is the worst case, so it will hit all of the # (restricted course, but pass all the checks)
# caching code. # This is the worst case, so it will hit all of the
with self.assertNumQueries(3): # caching code.
embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') with self.assertNumQueries(3):
embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0')
with self.assertNumQueries(0): with self.assertNumQueries(0):
embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0')
def test_caching_no_restricted_courses(self): def test_caching_no_restricted_courses(self):
RestrictedCourse.objects.all().delete() RestrictedCourse.objects.all().delete()
...@@ -176,6 +182,12 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase): ...@@ -176,6 +182,12 @@ class EmbargoCheckAccessApiTests(ModuleStoreTestCase):
with self.assertNumQueries(0): with self.assertNumQueries(0):
embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0') embargo_api.check_course_access(self.course.id, user=self.user, ip_address='0.0.0.0')
@contextmanager
def _mock_geoip(self, country_code):
with mock.patch.object(pygeoip.GeoIP, 'country_code_by_addr') as mock_ip:
mock_ip.return_value = country_code
yield
@ddt.ddt @ddt.ddt
@override_settings(MODULESTORE=MODULESTORE_CONFIG) @override_settings(MODULESTORE=MODULESTORE_CONFIG)
......
"""
Forms to support third-party to first-party OAuth 2.0 access token exchange
"""
from django.contrib.auth.models import User
from django.forms import CharField
from oauth2_provider.constants import SCOPE_NAMES
import provider.constants
from provider.forms import OAuthForm, OAuthValidationError
from provider.oauth2.forms import ScopeChoiceField, ScopeMixin
from provider.oauth2.models import Client
from requests import HTTPError
from social.backends import oauth as social_oauth
from third_party_auth import pipeline
class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
"""Form for access token exchange endpoint"""
access_token = CharField(required=False)
scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False)
client_id = CharField(required=False)
def __init__(self, request, *args, **kwargs):
super(AccessTokenExchangeForm, self).__init__(*args, **kwargs)
self.request = request
def _require_oauth_field(self, field_name):
"""
Raise an appropriate OAuthValidationError error if the field is missing
"""
field_val = self.cleaned_data.get(field_name)
if not field_val:
raise OAuthValidationError(
{
"error": "invalid_request",
"error_description": "{} is required".format(field_name),
}
)
return field_val
def clean_access_token(self):
return self._require_oauth_field("access_token")
def clean_client_id(self):
return self._require_oauth_field("client_id")
def clean(self):
if self._errors:
return {}
backend = self.request.social_strategy.backend
if not isinstance(backend, social_oauth.BaseOAuth2):
raise OAuthValidationError(
{
"error": "invalid_request",
"error_description": "{} is not a supported provider".format(backend.name),
}
)
self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API
client_id = self.cleaned_data["client_id"]
try:
client = Client.objects.get(client_id=client_id)
except Client.DoesNotExist:
raise OAuthValidationError(
{
"error": "invalid_client",
"error_description": "{} is not a valid client_id".format(client_id),
}
)
if client.client_type != provider.constants.PUBLIC:
raise OAuthValidationError(
{
# invalid_client isn't really the right code, but this mirrors
# https://github.com/edx/django-oauth2-provider/blob/edx/provider/oauth2/forms.py#L331
"error": "invalid_client",
"error_description": "{} is not a public client".format(client_id),
}
)
self.cleaned_data["client"] = client
user = None
try:
user = backend.do_auth(self.cleaned_data.get("access_token"))
except HTTPError:
pass
if user and isinstance(user, User):
self.cleaned_data["user"] = user
else:
# Ensure user does not re-enter the pipeline
self.request.social_strategy.clean_partial_pipeline()
raise OAuthValidationError(
{
"error": "invalid_grant",
"error_description": "access_token is not valid",
}
)
return self.cleaned_data
"""
A models.py is required to make this an app (until we move to Django 1.7)
"""
"""
Tests for OAuth token exchange forms
"""
import unittest
from django.conf import settings
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase
from django.test.client import RequestFactory
import httpretty
from provider import scope
import social.apps.django_app.utils as social_utils
from oauth_exchange.forms import AccessTokenExchangeForm
from oauth_exchange.tests.utils import (
AccessTokenExchangeTestMixin,
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeForm
"""
def setUp(self):
super(AccessTokenExchangeFormTest, self).setUp()
self.request = RequestFactory().post("dummy_url")
SessionMiddleware().process_request(self.request)
self.request.social_strategy = social_utils.load_strategy(self.request, self.BACKEND)
def _assert_error(self, data, expected_error, expected_error_description):
form = AccessTokenExchangeForm(request=self.request, data=data)
self.assertEqual(
form.errors,
{"error": expected_error, "error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.request.session)
def _assert_success(self, data, expected_scopes):
form = AccessTokenExchangeForm(request=self.request, data=data)
self.assertTrue(form.is_valid())
self.assertEqual(form.cleaned_data["user"], self.user)
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
self.assertEqual(scope.to_names(form.cleaned_data["scope"]), expected_scopes)
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestFacebook(
AccessTokenExchangeFormTest,
AccessTokenExchangeMixinFacebook,
TestCase
):
"""
Tests for AccessTokenExchangeForm used with Facebook
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestGoogle(
AccessTokenExchangeFormTest,
AccessTokenExchangeMixinGoogle,
TestCase
):
"""
Tests for AccessTokenExchangeForm used with Google
"""
pass
"""
Tests for OAuth token exchange views
"""
from datetime import timedelta
import json
import mock
import unittest
from django.conf import settings
from django.core.urlresolvers import reverse
from django.test import TestCase
import httpretty
import provider.constants
from provider import scope
from provider.oauth2.models import AccessToken
from oauth_exchange.tests.utils import (
AccessTokenExchangeTestMixin,
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeView
"""
def setUp(self):
super(AccessTokenExchangeViewTest, self).setUp()
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
def _assert_error(self, data, expected_error, expected_error_description):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response["Content-Type"], "application/json")
self.assertEqual(
json.loads(response.content),
{"error": expected_error, "error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_success(self, data, expected_scopes):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/json")
content = json.loads(response.content)
self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"})
self.assertEqual(content["token_type"], "Bearer")
self.assertLessEqual(
timedelta(seconds=int(content["expires_in"])),
provider.constants.EXPIRE_DELTA_PUBLIC
)
self.assertEqual(content["scope"], " ".join(expected_scopes))
token = AccessToken.objects.get(token=content["access_token"])
self.assertEqual(token.user, self.user)
self.assertEqual(token.client, self.oauth_client)
self.assertEqual(scope.to_names(token.scope), expected_scopes)
def test_single_access_token(self):
def extract_token(response):
return json.loads(response.content)["access_token"]
self._setup_provider_response(success=True)
for single_access_token in [True, False]:
with mock.patch(
"oauth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
single_access_token
):
first_response = self.client.post(self.url, self.data)
second_response = self.client.post(self.url, self.data)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
def test_get_method(self):
response = self.client.get(self.url, self.data)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.content),
{
"error": "invalid_request",
"error_description": "Only POST requests allowed.",
}
)
def test_invalid_provider(self):
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
response = self.client.post(url, self.data)
self.assertEqual(response.status_code, 404)
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestFacebook(
AccessTokenExchangeViewTest,
AccessTokenExchangeMixinFacebook,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Facebook
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestGoogle(
AccessTokenExchangeViewTest,
AccessTokenExchangeMixinGoogle,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Google
"""
pass
"""
Test utilities for OAuth access token exchange
"""
import json
import httpretty
import provider.constants
from provider.oauth2.models import Client
from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory
class AccessTokenExchangeTestMixin(object):
"""
A mixin to define test cases for access token exchange. The following
methods must be implemented by subclasses:
* _assert_error(data, expected_error, expected_error_description)
* _assert_success(data, expected_scopes)
"""
def setUp(self):
super(AccessTokenExchangeTestMixin, self).setUp()
self.client_id = "test_client_id"
self.oauth_client = Client.objects.create(
client_id=self.client_id,
client_type=provider.constants.PUBLIC
)
self.social_uid = "test_social_uid"
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
self.access_token = "test_access_token"
# Initialize to minimal data
self.data = {
"access_token": self.access_token,
"client_id": self.client_id,
}
def _setup_provider_response(self, success):
"""
Register a mock response for the third party user information endpoint;
success indicates whether the response status code should be 200 or 400
"""
if success:
status = 200
body = json.dumps({self.UID_FIELD: self.social_uid})
else:
status = 400
body = json.dumps({})
httpretty.register_uri(
httpretty.GET,
self.USER_URL,
body=body,
status=status,
content_type="application/json"
)
def _assert_error(self, _data, _expected_error, _expected_error_description):
"""
Given request data, execute a test and check that the expected error
was returned (along with any other appropriate assertions).
"""
raise NotImplementedError()
def _assert_success(self, data, expected_scopes):
"""
Given request data, execute a test and check that the expected scopes
were returned (along with any other appropriate assertions).
"""
raise NotImplementedError()
def test_minimal(self):
self._setup_provider_response(success=True)
self._assert_success(self.data, expected_scopes=[])
def test_scopes(self):
self._setup_provider_response(success=True)
self.data["scope"] = "profile email"
self._assert_success(self.data, expected_scopes=["profile", "email"])
def test_missing_fields(self):
for field in ["access_token", "client_id"]:
data = dict(self.data)
del data[field]
self._assert_error(data, "invalid_request", "{} is required".format(field))
def test_invalid_client(self):
self.data["client_id"] = "nonexistent_client"
self._assert_error(
self.data,
"invalid_client",
"nonexistent_client is not a valid client_id"
)
def test_confidential_client(self):
self.oauth_client.client_type = provider.constants.CONFIDENTIAL
self.oauth_client.save()
self._assert_error(
self.data,
"invalid_client",
"test_client_id is not a public client"
)
def test_invalid_acess_token(self):
self._setup_provider_response(success=False)
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
def test_no_linked_user(self):
UserSocialAuth.objects.all().delete()
self._setup_provider_response(success=True)
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
class AccessTokenExchangeMixinFacebook(object):
"""Tests access token exchange with the Facebook backend"""
BACKEND = "facebook"
USER_URL = "https://graph.facebook.com/me"
# In facebook responses, the "id" field is used as the user's identifier
UID_FIELD = "id"
class AccessTokenExchangeMixinGoogle(object):
"""Tests access token exchange with the Google backend"""
BACKEND = "google-oauth2"
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
# In google-oauth2 responses, the "email" field is used as the user's identifier
UID_FIELD = "email"
"""
Views to support third-party to first-party OAuth 2.0 access token exchange
"""
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from provider import constants
from provider.oauth2.views import AccessTokenView as AccessTokenView
import social.apps.django_app.utils as social_utils
from oauth_exchange.forms import AccessTokenExchangeForm
class AccessTokenExchangeView(AccessTokenView):
"""View for access token exchange"""
@method_decorator(csrf_exempt)
@method_decorator(social_utils.strategy("social:complete"))
def dispatch(self, *args, **kwargs):
return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs)
def get(self, request, _backend):
return super(AccessTokenExchangeView, self).get(request)
def post(self, request, _backend):
form = AccessTokenExchangeForm(request=request, data=request.POST)
if not form.is_valid():
return self.error_response(form.errors)
user = form.cleaned_data["user"]
scope = form.cleaned_data["scope"]
client = form.cleaned_data["client"]
if constants.SINGLE_ACCESS_TOKEN:
edx_access_token = self.get_access_token(request, user, scope, client)
else:
edx_access_token = self.create_access_token(request, user, scope, client)
return self.access_token_response(edx_access_token)
...@@ -76,6 +76,8 @@ NAME_TO_EVENT_TYPE_MAP = { ...@@ -76,6 +76,8 @@ NAME_TO_EVENT_TYPE_MAP = {
'edx.video.paused': 'pause_video', 'edx.video.paused': 'pause_video',
'edx.video.stopped': 'stop_video', 'edx.video.stopped': 'stop_video',
'edx.video.loaded': 'load_video', 'edx.video.loaded': 'load_video',
'edx.video.position.changed': 'seek_video',
'edx.video.seeked': 'seek_video',
'edx.video.transcript.shown': 'show_transcript', 'edx.video.transcript.shown': 'show_transcript',
'edx.video.transcript.hidden': 'hide_transcript', 'edx.video.transcript.hidden': 'hide_transcript',
} }
...@@ -101,11 +103,14 @@ class VideoEventProcessor(object): ...@@ -101,11 +103,14 @@ class VideoEventProcessor(object):
if name not in NAME_TO_EVENT_TYPE_MAP: if name not in NAME_TO_EVENT_TYPE_MAP:
return return
# Convert edx.video.seeked to edx.video.positiion.changed
if name == "edx.video.seeked":
event['name'] = "edx.video.position.changed"
event['event_type'] = NAME_TO_EVENT_TYPE_MAP[name] event['event_type'] = NAME_TO_EVENT_TYPE_MAP[name]
if 'event' not in event: if 'event' not in event:
return return
payload = event['event'] payload = event['event']
if 'module_id' in payload: if 'module_id' in payload:
...@@ -122,13 +127,38 @@ class VideoEventProcessor(object): ...@@ -122,13 +127,38 @@ class VideoEventProcessor(object):
if 'current_time' in payload: if 'current_time' in payload:
payload['currentTime'] = payload.pop('current_time') payload['currentTime'] = payload.pop('current_time')
event['event'] = json.dumps(payload) if 'context' in event:
context = event['context']
if 'context' not in event:
return # Converts seek_type to seek and skip|slide to onSlideSeek|onSkipSeek
if 'seek_type' in payload:
context = event['context'] seek_type = payload['seek_type']
if seek_type == 'slide':
payload['type'] = "onSlideSeek"
elif seek_type == 'skip':
payload['type'] = "onSkipSeek"
del payload['seek_type']
# For the iOS build that is returning a +30 for back skip 30
if (
context['application']['version'] == "1.0.02" and
context['application']['name'] == "edx.mobileapp.iOS"
):
if 'requested_skip_interval' in payload and 'type' in payload:
if (
payload['requested_skip_interval'] == 30 and
payload['type'] == "onSkipSeek"
):
payload['requested_skip_interval'] = -30
# For the Android build that isn't distinguishing between skip/seek
if 'requested_skip_interval' in payload:
if abs(payload['requested_skip_interval']) != 30:
if 'type' in payload:
payload['type'] = 'onSlideSeek'
if 'open_in_browser_url' in context:
page, _sep, _tail = context.pop('open_in_browser_url').rpartition('/')
event['page'] = page
if 'open_in_browser_url' in context: event['event'] = json.dumps(payload)
page, _sep, _tail = context.pop('open_in_browser_url').rpartition('/')
event['page'] = page
...@@ -316,6 +316,7 @@ class SegmentIOTrackingTestCase(EventTrackingTestCase): ...@@ -316,6 +316,7 @@ class SegmentIOTrackingTestCase(EventTrackingTestCase):
('edx.video.paused', 'pause_video'), ('edx.video.paused', 'pause_video'),
('edx.video.stopped', 'stop_video'), ('edx.video.stopped', 'stop_video'),
('edx.video.loaded', 'load_video'), ('edx.video.loaded', 'load_video'),
('edx.video.position.changed', 'seek_video'),
('edx.video.transcript.shown', 'show_transcript'), ('edx.video.transcript.shown', 'show_transcript'),
('edx.video.transcript.hidden', 'hide_transcript'), ('edx.video.transcript.hidden', 'hide_transcript'),
) )
...@@ -404,3 +405,127 @@ class SegmentIOTrackingTestCase(EventTrackingTestCase): ...@@ -404,3 +405,127 @@ class SegmentIOTrackingTestCase(EventTrackingTestCase):
self.assertEqualUnicode(actual_event, expected_event_without_payload) self.assertEqualUnicode(actual_event, expected_event_without_payload)
self.assertEqualUnicode(payload, expected_payload) self.assertEqualUnicode(payload, expected_payload)
@data(
# Verify positive slide case. Verify slide to onSlideSeek. Verify edx.video.seeked emitted from iOS v1.0.02 is changed to edx.video.position.changed.
(1, 1, "seek_type", "slide", "onSlideSeek", "edx.video.seeked", "edx.video.position.changed", 'edx.mobileapp.iOS', '1.0.02'),
# Verify negative slide case. Verify slide to onSlideSeek. Verify edx.video.seeked to edx.video.position.changed.
(-2, -2, "seek_type", "slide", "onSlideSeek", "edx.video.seeked", "edx.video.position.changed", 'edx.mobileapp.iOS', '1.0.02'),
# Verify +30 is changed to -30 which is incorrectly emitted in iOS v1.0.02. Verify skip to onSkipSeek
(30, -30, "seek_type", "skip", "onSkipSeek", "edx.video.position.changed", "edx.video.position.changed", 'edx.mobileapp.iOS', '1.0.02'),
# Verify the correct case of -30 is also handled as well. Verify skip to onSkipSeek
(-30, -30, "seek_type", "skip", "onSkipSeek", "edx.video.position.changed", "edx.video.position.changed", 'edx.mobileapp.iOS', '1.0.02'),
# Verify positive slide case where onSkipSeek is changed to onSlideSkip. Verify edx.video.seeked emitted from Android v1.0.02 is changed to edx.video.position.changed.
(1, 1, "type", "onSkipSeek", "onSlideSeek", "edx.video.seeked", "edx.video.position.changed", 'edx.mobileapp.android', '1.0.02'),
# Verify positive slide case where onSkipSeek is changed to onSlideSkip. Verify edx.video.seeked emitted from Android v1.0.02 is changed to edx.video.position.changed.
(-2, -2, "type", "onSkipSeek", "onSlideSeek", "edx.video.seeked", "edx.video.position.changed", 'edx.mobileapp.android', '1.0.02'),
# Verify positive skip case where onSkipSeek is not changed and does not become negative.
(30, 30, "type", "onSkipSeek", "onSkipSeek", "edx.video.position.changed", "edx.video.position.changed", 'edx.mobileapp.android', '1.0.02'),
# Verify positive skip case where onSkipSeek is not changed.
(-30, -30, "type", "onSkipSeek", "onSkipSeek", "edx.video.position.changed", "edx.video.position.changed", 'edx.mobileapp.android', '1.0.02')
)
@unpack
def test_previous_builds(self,
requested_skip_interval,
expected_skip_interval,
seek_type_key,
seek_type,
expected_seek_type,
name,
expected_name,
platform,
version,
):
"""
Test backwards compatibility of previous app builds
iOS version 1.0.02: Incorrectly emits the skip back 30 seconds as +30
instead of -30.
Android version 1.0.02: Skip and slide were both being returned as a
skip. Skip or slide is determined by checking if the skip time is == -30
Additionally, for both of the above mentioned versions, edx.video.seeked
was sent instead of edx.video.position.changed
"""
course_id = 'foo/bar/baz'
middleware = TrackMiddleware()
input_payload = {
"code": "mobile",
"new_time": 89.699177437,
"old_time": 119.699177437,
seek_type_key: seek_type,
"requested_skip_interval": requested_skip_interval,
'module_id': 'i4x://foo/bar/baz/some_module',
}
request = self.create_request(
data=self.create_segmentio_event_json(
name=name,
data=input_payload,
context={
'open_in_browser_url': 'https://testserver/courses/foo/bar/baz/courseware/Week_1/Activity/2',
'course_id': course_id,
'application': {
'name': platform,
'version': version,
'component': 'videoplayer'
}
},
),
content_type='application/json'
)
User.objects.create(pk=USER_ID, username=str(sentinel.username))
middleware.process_request(request)
try:
response = segmentio.segmentio_event(request)
self.assertEquals(response.status_code, 200)
expected_event_without_payload = {
'accept_language': '',
'referer': '',
'username': str(sentinel.username),
'ip': '',
'session': '',
'event_source': 'mobile',
'event_type': "seek_video",
'name': expected_name,
'agent': str(sentinel.user_agent),
'page': 'https://testserver/courses/foo/bar/baz/courseware/Week_1/Activity',
'time': datetime.strptime("2014-08-27T16:33:39.215Z", "%Y-%m-%dT%H:%M:%S.%fZ"),
'host': 'testserver',
'context': {
'user_id': USER_ID,
'course_id': course_id,
'org_id': 'foo',
'path': ENDPOINT,
'client': {
'library': {
'name': 'test-app',
'version': 'unknown'
},
'app': {
'version': '1.0.1',
},
},
'application': {
'name': platform,
'version': version,
'component': 'videoplayer'
},
'received_at': datetime.strptime("2014-08-27T16:33:39.100Z", "%Y-%m-%dT%H:%M:%S.%fZ"),
},
}
expected_payload = {
"code": "mobile",
"new_time": 89.699177437,
"old_time": 119.699177437,
"type": expected_seek_type,
"requested_skip_interval": expected_skip_interval,
'id': 'i4x-foo-bar-baz-some_module',
}
finally:
middleware.process_response(request, None)
actual_event = dict(self.get_event())
payload = json.loads(actual_event.pop('event'))
self.assertEqualUnicode(actual_event, expected_event_without_payload)
self.assertEqualUnicode(payload, expected_payload)
...@@ -311,13 +311,15 @@ class DiscussionSortPreferencePage(CoursePage): ...@@ -311,13 +311,15 @@ class DiscussionSortPreferencePage(CoursePage):
class DiscussionTabSingleThreadPage(CoursePage): class DiscussionTabSingleThreadPage(CoursePage):
def __init__(self, browser, course_id, thread_id): def __init__(self, browser, course_id, discussion_id, thread_id):
super(DiscussionTabSingleThreadPage, self).__init__(browser, course_id) super(DiscussionTabSingleThreadPage, self).__init__(browser, course_id)
self.thread_page = DiscussionThreadPage( self.thread_page = DiscussionThreadPage(
browser, browser,
"body.discussion .discussion-article[data-id='{thread_id}']".format(thread_id=thread_id) "body.discussion .discussion-article[data-id='{thread_id}']".format(thread_id=thread_id)
) )
self.url_path = "discussion/forum/dummy/threads/" + thread_id self.url_path = "discussion/forum/{discussion_id}/threads/{thread_id}".format(
discussion_id=discussion_id, thread_id=thread_id
)
def is_browser_on_page(self): def is_browser_on_page(self):
return self.thread_page.is_browser_on_page() return self.thread_page.is_browser_on_page()
......
...@@ -5,12 +5,15 @@ Helper functions and classes for discussion tests. ...@@ -5,12 +5,15 @@ Helper functions and classes for discussion tests.
from uuid import uuid4 from uuid import uuid4
import json import json
from ...fixtures import LMS_BASE_URL
from ...fixtures.course import CourseFixture
from ...fixtures.discussion import ( from ...fixtures.discussion import (
SingleThreadViewFixture, SingleThreadViewFixture,
Thread, Thread,
Response, Response,
) )
from ...fixtures import LMS_BASE_URL from ...pages.lms.discussion import DiscussionTabSingleThreadPage
from ...tests.helpers import UniqueCourseTest
class BaseDiscussionMixin(object): class BaseDiscussionMixin(object):
...@@ -83,3 +86,22 @@ class CohortTestMixin(object): ...@@ -83,3 +86,22 @@ class CohortTestMixin(object):
data = {"users": username} data = {"users": username}
response = course_fixture.session.post(url, data=data, headers=course_fixture.headers) response = course_fixture.session.post(url, data=data, headers=course_fixture.headers)
self.assertTrue(response.ok, "Failed to add user to cohort") self.assertTrue(response.ok, "Failed to add user to cohort")
class BaseDiscussionTestCase(UniqueCourseTest):
def setUp(self):
super(BaseDiscussionTestCase, self).setUp()
self.discussion_id = "test_discussion_{}".format(uuid4().hex)
self.course_fixture = CourseFixture(**self.course_info)
self.course_fixture.add_advanced_settings(
{'discussion_topics': {'value': {'Test Discussion Topic': {'id': self.discussion_id}}}}
)
self.course_fixture.install()
def create_single_thread_page(self, thread_id):
"""
Sets up a `DiscussionTabSingleThreadPage` for a given
`thread_id`.
"""
return DiscussionTabSingleThreadPage(self.browser, self.course_id, self.discussion_id, thread_id)
...@@ -3,7 +3,7 @@ Tests related to the cohorting feature. ...@@ -3,7 +3,7 @@ Tests related to the cohorting feature.
""" """
from uuid import uuid4 from uuid import uuid4
from .helpers import BaseDiscussionMixin from .helpers import BaseDiscussionMixin, BaseDiscussionTestCase
from .helpers import CohortTestMixin from .helpers import CohortTestMixin
from ..helpers import UniqueCourseTest from ..helpers import UniqueCourseTest
from ...pages.lms.auto_auth import AutoAuthPage from ...pages.lms.auto_auth import AutoAuthPage
...@@ -57,20 +57,17 @@ class CohortedDiscussionTestMixin(BaseDiscussionMixin, CohortTestMixin): ...@@ -57,20 +57,17 @@ class CohortedDiscussionTestMixin(BaseDiscussionMixin, CohortTestMixin):
self.assertEquals(self.thread_page.get_group_visibility_label(), "This post is visible to everyone.") self.assertEquals(self.thread_page.get_group_visibility_label(), "This post is visible to everyone.")
class DiscussionTabSingleThreadTest(UniqueCourseTest): class DiscussionTabSingleThreadTest(BaseDiscussionTestCase):
""" """
Tests for the discussion page displaying a single thread. Tests for the discussion page displaying a single thread.
""" """
def setUp(self): def setUp(self):
super(DiscussionTabSingleThreadTest, self).setUp() super(DiscussionTabSingleThreadTest, self).setUp()
self.discussion_id = "test_discussion_{}".format(uuid4().hex)
# Create a course to register for
self.course_fixture = CourseFixture(**self.course_info).install()
self.setup_cohorts() self.setup_cohorts()
AutoAuthPage(self.browser, course_id=self.course_id).visit() AutoAuthPage(self.browser, course_id=self.course_id).visit()
def setup_thread_page(self, thread_id): def setup_thread_page(self, thread_id):
self.thread_page = DiscussionTabSingleThreadPage(self.browser, self.course_id, thread_id) # pylint: disable=attribute-defined-outside-init self.thread_page = DiscussionTabSingleThreadPage(self.browser, self.course_id, self.discussion_id, thread_id) # pylint: disable=attribute-defined-outside-init
self.thread_page.visit() self.thread_page.visit()
# pylint: disable=unused-argument # pylint: disable=unused-argument
......
...@@ -18,7 +18,8 @@ from xmodule.modulestore.xml import CourseLocationManager ...@@ -18,7 +18,8 @@ from xmodule.modulestore.xml import CourseLocationManager
from xmodule.tests import get_test_system from xmodule.tests import get_test_system
from courseware.tests.factories import GlobalStaffFactory, StaffFactory from courseware.tests.factories import GlobalStaffFactory, StaffFactory
from openedx.core.djangoapps.content.course_structures.models import CourseStructure, update_course_structure from openedx.core.djangoapps.content.course_structures.models import CourseStructure
from openedx.core.djangoapps.content.course_structures.tasks import update_course_structure
TEST_SERVER_HOST = 'http://testserver' TEST_SERVER_HOST = 'http://testserver'
......
...@@ -14,7 +14,7 @@ from opaque_keys.edx.keys import CourseKey ...@@ -14,7 +14,7 @@ from opaque_keys.edx.keys import CourseKey
from course_structure_api.v0 import serializers from course_structure_api.v0 import serializers
from courseware import courses from courseware import courses
from courseware.access import has_access from courseware.access import has_access
from openedx.core.djangoapps.content.course_structures import models from openedx.core.djangoapps.content.course_structures import models, tasks
from openedx.core.lib.api.permissions import IsAuthenticatedOrDebug from openedx.core.lib.api.permissions import IsAuthenticatedOrDebug
from openedx.core.lib.api.serializers import PaginationSerializer from openedx.core.lib.api.serializers import PaginationSerializer
from student.roles import CourseInstructorRole, CourseStaffRole from student.roles import CourseInstructorRole, CourseStaffRole
...@@ -191,7 +191,7 @@ class CourseStructure(CourseViewMixin, RetrieveAPIView): ...@@ -191,7 +191,7 @@ class CourseStructure(CourseViewMixin, RetrieveAPIView):
return super(CourseStructure, self).retrieve(request, *args, **kwargs) return super(CourseStructure, self).retrieve(request, *args, **kwargs)
except models.CourseStructure.DoesNotExist: except models.CourseStructure.DoesNotExist:
# If we don't have data stored, generate it and return a 503. # If we don't have data stored, generate it and return a 503.
models.update_course_structure.delay(unicode(self.course.id)) tasks.update_course_structure.delay(unicode(self.course.id))
return Response(status=503, headers={'Retry-After': '120'}) return Response(status=503, headers={'Retry-After': '120'})
def get_object(self, queryset=None): def get_object(self, queryset=None):
......
...@@ -14,7 +14,7 @@ from lms.lib.comment_client import Thread ...@@ -14,7 +14,7 @@ from lms.lib.comment_client import Thread
from xmodule.modulestore.tests.django_utils import TEST_DATA_MOCK_MODULESTORE from xmodule.modulestore.tests.django_utils import TEST_DATA_MOCK_MODULESTORE
from django_comment_client.base import views from django_comment_client.base import views
from django_comment_client.tests.group_id import CohortedTopicGroupIdTestMixin, NonCohortedTopicGroupIdTestMixin, GroupIdAssertionMixin from django_comment_client.tests.group_id import CohortedTopicGroupIdTestMixin, NonCohortedTopicGroupIdTestMixin, GroupIdAssertionMixin
from django_comment_client.tests.utils import CohortedContentTestCase from django_comment_client.tests.utils import CohortedTestCase
from django_comment_client.tests.unicode import UnicodeTestMixin from django_comment_client.tests.unicode import UnicodeTestMixin
from django_comment_common.models import Role from django_comment_common.models import Role
from django_comment_common.utils import seed_permissions_roles from django_comment_common.utils import seed_permissions_roles
...@@ -42,7 +42,7 @@ class MockRequestSetupMixin(object): ...@@ -42,7 +42,7 @@ class MockRequestSetupMixin(object):
@patch('lms.lib.comment_client.utils.requests.request') @patch('lms.lib.comment_client.utils.requests.request')
class CreateThreadGroupIdTestCase( class CreateThreadGroupIdTestCase(
MockRequestSetupMixin, MockRequestSetupMixin,
CohortedContentTestCase, CohortedTestCase,
CohortedTopicGroupIdTestMixin, CohortedTopicGroupIdTestMixin,
NonCohortedTopicGroupIdTestMixin NonCohortedTopicGroupIdTestMixin
): ):
...@@ -77,7 +77,7 @@ class CreateThreadGroupIdTestCase( ...@@ -77,7 +77,7 @@ class CreateThreadGroupIdTestCase(
@patch('lms.lib.comment_client.utils.requests.request') @patch('lms.lib.comment_client.utils.requests.request')
class ThreadActionGroupIdTestCase( class ThreadActionGroupIdTestCase(
MockRequestSetupMixin, MockRequestSetupMixin,
CohortedContentTestCase, CohortedTestCase,
GroupIdAssertionMixin GroupIdAssertionMixin
): ):
def call_view( def call_view(
......
...@@ -74,7 +74,7 @@ def track_forum_event(request, event_name, course, obj, data, id_map=None): ...@@ -74,7 +74,7 @@ def track_forum_event(request, event_name, course, obj, data, id_map=None):
user = request.user user = request.user
data['id'] = obj.id data['id'] = obj.id
if id_map is None: if id_map is None:
id_map = get_discussion_id_map(course) id_map = get_discussion_id_map(course, user)
commentable_id = data['commentable_id'] commentable_id = data['commentable_id']
if commentable_id in id_map: if commentable_id in id_map:
...@@ -173,9 +173,9 @@ def create_thread(request, course_id, commentable_id): ...@@ -173,9 +173,9 @@ def create_thread(request, course_id, commentable_id):
# Calls to id map are expensive, but we need this more than once. # Calls to id map are expensive, but we need this more than once.
# Prefetch it. # Prefetch it.
id_map = get_discussion_id_map(course) id_map = get_discussion_id_map(course, request.user)
add_courseware_context([data], course, id_map=id_map) add_courseware_context([data], course, request.user, id_map=id_map)
track_forum_event(request, 'edx.forum.thread.created', track_forum_event(request, 'edx.forum.thread.created',
course, thread, event_data, id_map=id_map) course, thread, event_data, id_map=id_map)
...@@ -208,7 +208,7 @@ def update_thread(request, course_id, thread_id): ...@@ -208,7 +208,7 @@ def update_thread(request, course_id, thread_id):
thread.thread_type = request.POST["thread_type"] thread.thread_type = request.POST["thread_type"]
if "commentable_id" in request.POST: if "commentable_id" in request.POST:
course = get_course_with_access(request.user, 'load', course_key) course = get_course_with_access(request.user, 'load', course_key)
commentable_ids = get_discussion_categories_ids(course) commentable_ids = get_discussion_categories_ids(course, request.user)
if request.POST.get("commentable_id") in commentable_ids: if request.POST.get("commentable_id") in commentable_ids:
thread.commentable_id = request.POST["commentable_id"] thread.commentable_id = request.POST["commentable_id"]
else: else:
......
...@@ -52,7 +52,7 @@ def _attr_safe_json(obj): ...@@ -52,7 +52,7 @@ def _attr_safe_json(obj):
@newrelic.agent.function_trace() @newrelic.agent.function_trace()
def make_course_settings(course): def make_course_settings(course, user):
""" """
Generate a JSON-serializable model for course settings, which will be used to initialize a Generate a JSON-serializable model for course settings, which will be used to initialize a
DiscussionCourseSettings object on the client. DiscussionCourseSettings object on the client.
...@@ -63,14 +63,14 @@ def make_course_settings(course): ...@@ -63,14 +63,14 @@ def make_course_settings(course):
'allow_anonymous': course.allow_anonymous, 'allow_anonymous': course.allow_anonymous,
'allow_anonymous_to_peers': course.allow_anonymous_to_peers, 'allow_anonymous_to_peers': course.allow_anonymous_to_peers,
'cohorts': [{"id": str(g.id), "name": g.name} for g in get_course_cohorts(course)], 'cohorts': [{"id": str(g.id), "name": g.name} for g in get_course_cohorts(course)],
'category_map': utils.get_discussion_category_map(course) 'category_map': utils.get_discussion_category_map(course, user)
} }
return obj return obj
@newrelic.agent.function_trace() @newrelic.agent.function_trace()
def get_threads(request, course_key, discussion_id=None, per_page=THREADS_PER_PAGE): def get_threads(request, course, discussion_id=None, per_page=THREADS_PER_PAGE):
""" """
This may raise an appropriate subclass of cc.utils.CommentClientError This may raise an appropriate subclass of cc.utils.CommentClientError
if something goes wrong, or ValueError if the group_id is invalid. if something goes wrong, or ValueError if the group_id is invalid.
...@@ -81,12 +81,16 @@ def get_threads(request, course_key, discussion_id=None, per_page=THREADS_PER_PA ...@@ -81,12 +81,16 @@ def get_threads(request, course_key, discussion_id=None, per_page=THREADS_PER_PA
'sort_key': 'date', 'sort_key': 'date',
'sort_order': 'desc', 'sort_order': 'desc',
'text': '', 'text': '',
'commentable_id': discussion_id, 'course_id': unicode(course.id),
'course_id': course_key.to_deprecated_string(),
'user_id': request.user.id, 'user_id': request.user.id,
'group_id': get_group_id_for_comments_service(request, course_key, discussion_id), # may raise ValueError 'group_id': get_group_id_for_comments_service(request, course.id, discussion_id), # may raise ValueError
} }
# If provided with a discussion id, filter by discussion id in the
# comments_service.
if discussion_id is not None:
default_query_params['commentable_id'] = discussion_id
if not request.GET.get('sort_key'): if not request.GET.get('sort_key'):
# If the user did not select a sort key, use their last used sort key # If the user did not select a sort key, use their last used sort key
cc_user = cc.User.from_django_user(request.user) cc_user = cc.User.from_django_user(request.user)
...@@ -124,6 +128,15 @@ def get_threads(request, course_key, discussion_id=None, per_page=THREADS_PER_PA ...@@ -124,6 +128,15 @@ def get_threads(request, course_key, discussion_id=None, per_page=THREADS_PER_PA
threads, page, num_pages, corrected_text = cc.Thread.search(query_params) threads, page, num_pages, corrected_text = cc.Thread.search(query_params)
# If not provided with a discussion id, filter threads by commentable ids
# which are accessible to the current user.
if discussion_id is None:
discussion_category_ids = set(utils.get_discussion_categories_ids(course, request.user))
threads = [
thread for thread in threads
if thread.get('commentable_id') in discussion_category_ids
]
for thread in threads: for thread in threads:
# patch for backward compatibility to comments service # patch for backward compatibility to comments service
if 'pinned' not in thread: if 'pinned' not in thread:
...@@ -163,7 +176,7 @@ def inline_discussion(request, course_key, discussion_id): ...@@ -163,7 +176,7 @@ def inline_discussion(request, course_key, discussion_id):
user_info = cc_user.to_dict() user_info = cc_user.to_dict()
try: try:
threads, query_params = get_threads(request, course_key, discussion_id, per_page=INLINE_THREADS_PER_PAGE) threads, query_params = get_threads(request, course, discussion_id, per_page=INLINE_THREADS_PER_PAGE)
except ValueError: except ValueError:
return HttpResponseBadRequest("Invalid group_id") return HttpResponseBadRequest("Invalid group_id")
...@@ -172,7 +185,7 @@ def inline_discussion(request, course_key, discussion_id): ...@@ -172,7 +185,7 @@ def inline_discussion(request, course_key, discussion_id):
is_staff = cached_has_permission(request.user, 'openclose_thread', course.id) is_staff = cached_has_permission(request.user, 'openclose_thread', course.id)
threads = [utils.prepare_content(thread, course_key, is_staff) for thread in threads] threads = [utils.prepare_content(thread, course_key, is_staff) for thread in threads]
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context(threads, course) add_courseware_context(threads, course, request.user)
return utils.JsonResponse({ return utils.JsonResponse({
'is_commentable_cohorted': is_commentable_cohorted(course_key, discussion_id), 'is_commentable_cohorted': is_commentable_cohorted(course_key, discussion_id),
'discussion_data': threads, 'discussion_data': threads,
...@@ -181,7 +194,7 @@ def inline_discussion(request, course_key, discussion_id): ...@@ -181,7 +194,7 @@ def inline_discussion(request, course_key, discussion_id):
'page': query_params['page'], 'page': query_params['page'],
'num_pages': query_params['num_pages'], 'num_pages': query_params['num_pages'],
'roles': utils.get_role_ids(course_key), 'roles': utils.get_role_ids(course_key),
'course_settings': make_course_settings(course) 'course_settings': make_course_settings(course, request.user)
}) })
...@@ -194,13 +207,13 @@ def forum_form_discussion(request, course_key): ...@@ -194,13 +207,13 @@ def forum_form_discussion(request, course_key):
nr_transaction = newrelic.agent.current_transaction() nr_transaction = newrelic.agent.current_transaction()
course = get_course_with_access(request.user, 'load_forum', course_key, check_if_enrolled=True) course = get_course_with_access(request.user, 'load_forum', course_key, check_if_enrolled=True)
course_settings = make_course_settings(course) course_settings = make_course_settings(course, request.user)
user = cc.User.from_django_user(request.user) user = cc.User.from_django_user(request.user)
user_info = user.to_dict() user_info = user.to_dict()
try: try:
unsafethreads, query_params = get_threads(request, course_key) # This might process a search query unsafethreads, query_params = get_threads(request, course) # This might process a search query
is_staff = cached_has_permission(request.user, 'openclose_thread', course.id) is_staff = cached_has_permission(request.user, 'openclose_thread', course.id)
threads = [utils.prepare_content(thread, course_key, is_staff) for thread in unsafethreads] threads = [utils.prepare_content(thread, course_key, is_staff) for thread in unsafethreads]
except cc.utils.CommentClientMaintenanceError: except cc.utils.CommentClientMaintenanceError:
...@@ -213,7 +226,7 @@ def forum_form_discussion(request, course_key): ...@@ -213,7 +226,7 @@ def forum_form_discussion(request, course_key):
annotated_content_info = utils.get_metadata_for_threads(course_key, threads, request.user, user_info) annotated_content_info = utils.get_metadata_for_threads(course_key, threads, request.user, user_info)
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context(threads, course) add_courseware_context(threads, course, request.user)
if request.is_ajax(): if request.is_ajax():
return utils.JsonResponse({ return utils.JsonResponse({
...@@ -258,15 +271,18 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -258,15 +271,18 @@ def single_thread(request, course_key, discussion_id, thread_id):
""" """
Renders a response to display a single discussion thread. Renders a response to display a single discussion thread.
""" """
nr_transaction = newrelic.agent.current_transaction() nr_transaction = newrelic.agent.current_transaction()
course = get_course_with_access(request.user, 'load_forum', course_key) course = get_course_with_access(request.user, 'load_forum', course_key)
course_settings = make_course_settings(course) course_settings = make_course_settings(course, request.user)
cc_user = cc.User.from_django_user(request.user) cc_user = cc.User.from_django_user(request.user)
user_info = cc_user.to_dict() user_info = cc_user.to_dict()
is_moderator = cached_has_permission(request.user, "see_all_cohorts", course_key) is_moderator = cached_has_permission(request.user, "see_all_cohorts", course_key)
# Verify that the student has access to this thread if belongs to a discussion module
if discussion_id not in utils.get_discussion_categories_ids(course, request.user):
raise Http404
# Currently, the front end always loads responses via AJAX, even for this # Currently, the front end always loads responses via AJAX, even for this
# page; it would be a nice optimization to avoid that extra round trip to # page; it would be a nice optimization to avoid that extra round trip to
# the comments service. # the comments service.
...@@ -294,7 +310,7 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -294,7 +310,7 @@ def single_thread(request, course_key, discussion_id, thread_id):
annotated_content_info = utils.get_annotated_content_infos(course_key, thread, request.user, user_info=user_info) annotated_content_info = utils.get_annotated_content_infos(course_key, thread, request.user, user_info=user_info)
content = utils.prepare_content(thread.to_dict(), course_key, is_staff) content = utils.prepare_content(thread.to_dict(), course_key, is_staff)
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context([content], course) add_courseware_context([content], course, request.user)
return utils.JsonResponse({ return utils.JsonResponse({
'content': content, 'content': content,
'annotated_content_info': annotated_content_info, 'annotated_content_info': annotated_content_info,
...@@ -302,13 +318,13 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -302,13 +318,13 @@ def single_thread(request, course_key, discussion_id, thread_id):
else: else:
try: try:
threads, query_params = get_threads(request, course_key) threads, query_params = get_threads(request, course)
except ValueError: except ValueError:
return HttpResponseBadRequest("Invalid group_id") return HttpResponseBadRequest("Invalid group_id")
threads.append(thread.to_dict()) threads.append(thread.to_dict())
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context(threads, course) add_courseware_context(threads, course, request.user)
for thread in threads: for thread in threads:
# patch for backward compatibility with comments service # patch for backward compatibility with comments service
......
"""
Utilities for tests within the django_comment_client module.
"""
from datetime import datetime
from mock import patch from mock import patch
from pytz import UTC
from openedx.core.djangoapps.course_groups.models import CourseUserGroup from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from django_comment_common.models import Role from django_comment_common.models import Role
from django_comment_common.utils import seed_permissions_roles from django_comment_common.utils import seed_permissions_roles
from student.tests.factories import CourseEnrollmentFactory, UserFactory from student.tests.factories import CourseEnrollmentFactory, UserFactory
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.partitions.partitions import UserPartition, Group
class CohortedContentTestCase(ModuleStoreTestCase): class CohortedTestCase(ModuleStoreTestCase):
""" """
Sets up a course with a student, a moderator and their cohorts. Sets up a course with a student, a moderator and their cohorts.
""" """
@patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self): def setUp(self):
super(CohortedContentTestCase, self).setUp() super(CohortedTestCase, self).setUp()
self.course = CourseFactory.create( self.course = CourseFactory.create(
discussion_topics={
"cohorted topic": {"id": "cohorted_topic"},
"non-cohorted topic": {"id": "non_cohorted_topic"},
},
cohort_config={ cohort_config={
"cohorted": True, "cohorted": True,
"cohorted_discussions": ["cohorted_topic"] "cohorted_discussions": ["cohorted_topic"]
} }
) )
self.student_cohort = CourseUserGroup.objects.create( self.student_cohort = CohortFactory.create(
name="student_cohort", name="student_cohort",
course_id=self.course.id, course_id=self.course.id
group_type=CourseUserGroup.COHORT
) )
self.moderator_cohort = CourseUserGroup.objects.create( self.moderator_cohort = CohortFactory.create(
name="moderator_cohort", name="moderator_cohort",
course_id=self.course.id, course_id=self.course.id
group_type=CourseUserGroup.COHORT
) )
self.course.discussion_topics["cohorted topic"] = {"id": "cohorted_topic"}
self.course.discussion_topics["non-cohorted topic"] = {"id": "non_cohorted_topic"}
self.store.update_item(self.course, self.user.id)
seed_permissions_roles(self.course.id) seed_permissions_roles(self.course.id)
self.student = UserFactory.create() self.student = UserFactory.create()
...@@ -45,3 +49,82 @@ class CohortedContentTestCase(ModuleStoreTestCase): ...@@ -45,3 +49,82 @@ class CohortedContentTestCase(ModuleStoreTestCase):
self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id)) self.moderator.roles.add(Role.objects.get(name="Moderator", course_id=self.course.id))
self.student_cohort.users.add(self.student) self.student_cohort.users.add(self.student)
self.moderator_cohort.users.add(self.moderator) self.moderator_cohort.users.add(self.moderator)
class ContentGroupTestCase(ModuleStoreTestCase):
"""
Sets up discussion modules visible to content groups 'Alpha' and
'Beta', as well as a module visible to all students. Creates a
staff user, users with access to Alpha/Beta (by way of cohorts),
and a non-cohorted user with no special access.
"""
def setUp(self):
super(ContentGroupTestCase, self).setUp()
self.course = CourseFactory.create(
org='org', number='number', run='run',
# This test needs to use a course that has already started --
# discussion topics only show up if the course has already started,
# and the default start date for courses is Jan 1, 2030.
start=datetime(2012, 2, 3, tzinfo=UTC),
user_partitions=[
UserPartition(
0,
'Content Group Configuration',
'',
[Group(1, 'Alpha'), Group(2, 'Beta')],
scheme_id='cohort'
)
],
cohort_config={'cohorted': True},
discussion_topics={}
)
self.staff_user = UserFactory.create(is_staff=True)
self.alpha_user = UserFactory.create()
self.beta_user = UserFactory.create()
self.non_cohorted_user = UserFactory.create()
for user in [self.staff_user, self.alpha_user, self.beta_user, self.non_cohorted_user]:
CourseEnrollmentFactory.create(user=user, course_id=self.course.id)
alpha_cohort = CohortFactory(
course_id=self.course.id,
name='Cohort Alpha',
users=[self.alpha_user]
)
beta_cohort = CohortFactory(
course_id=self.course.id,
name='Cohort Beta',
users=[self.beta_user]
)
CourseUserGroupPartitionGroup.objects.create(
course_user_group=alpha_cohort,
partition_id=self.course.user_partitions[0].id,
group_id=self.course.user_partitions[0].groups[0].id
)
CourseUserGroupPartitionGroup.objects.create(
course_user_group=beta_cohort,
partition_id=self.course.user_partitions[0].id,
group_id=self.course.user_partitions[0].groups[1].id
)
self.alpha_module = ItemFactory.create(
parent_location=self.course.location,
category='discussion',
discussion_id='alpha_group_discussion',
discussion_target='Visible to Alpha',
group_access={self.course.user_partitions[0].id: [self.course.user_partitions[0].groups[0].id]}
)
self.beta_module = ItemFactory.create(
parent_location=self.course.location,
category='discussion',
discussion_id='beta_group_discussion',
discussion_target='Visible to Beta',
group_access={self.course.user_partitions[0].id: [self.course.user_partitions[0].groups[1].id]}
)
self.global_module = ItemFactory.create(
parent_location=self.course.location,
category='discussion',
discussion_id='global_group_discussion',
discussion_target='Visible to Everyone'
)
self.course = self.store.get_item(self.course.location)
...@@ -18,6 +18,7 @@ from django_comment_common.models import Role, FORUM_ROLE_STUDENT ...@@ -18,6 +18,7 @@ from django_comment_common.models import Role, FORUM_ROLE_STUDENT
from django_comment_client.permissions import check_permissions_by_view, cached_has_permission from django_comment_client.permissions import check_permissions_by_view, cached_has_permission
from edxmako import lookup_template from edxmako import lookup_template
from courseware.access import has_access
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_by_id, get_cohort_id, is_commentable_cohorted, \ from openedx.core.djangoapps.course_groups.cohorts import get_cohort_by_id, get_cohort_id, is_commentable_cohorted, \
is_course_cohorted is_course_cohorted
from openedx.core.djangoapps.course_groups.models import CourseUserGroup from openedx.core.djangoapps.course_groups.models import CourseUserGroup
...@@ -59,9 +60,11 @@ def has_forum_access(uname, course_id, rolename): ...@@ -59,9 +60,11 @@ def has_forum_access(uname, course_id, rolename):
return role.users.filter(username=uname).exists() return role.users.filter(username=uname).exists()
def _get_discussion_modules(course): # pylint: disable=invalid-name
def get_accessible_discussion_modules(course, user):
""" """
Return a list of all valid discussion modules in this course. Return a list of all valid discussion modules in this course that
are accessible to the given user.
""" """
all_modules = modulestore().get_items(course.id, qualifiers={'category': 'discussion'}) all_modules = modulestore().get_items(course.id, qualifiers={'category': 'discussion'})
...@@ -72,12 +75,16 @@ def _get_discussion_modules(course): ...@@ -72,12 +75,16 @@ def _get_discussion_modules(course):
return False return False
return True return True
return filter(has_required_keys, all_modules) return [
module for module in all_modules
if has_required_keys(module) and has_access(user, 'load', module, course.id)
]
def get_discussion_id_map(course): def get_discussion_id_map(course, user):
""" """
Transform the list of this course's discussion modules into a dictionary of metadata keyed by discussion_id. Transform the list of this course's discussion modules (visible to a given user) into a dictionary of metadata keyed
by discussion_id.
""" """
def get_entry(module): # pylint: disable=missing-docstring def get_entry(module): # pylint: disable=missing-docstring
discussion_id = module.discussion_id discussion_id = module.discussion_id
...@@ -85,7 +92,7 @@ def get_discussion_id_map(course): ...@@ -85,7 +92,7 @@ def get_discussion_id_map(course):
last_category = module.discussion_category.split("/")[-1].strip() last_category = module.discussion_category.split("/")[-1].strip()
return (discussion_id, {"location": module.location, "title": last_category + " / " + title}) return (discussion_id, {"location": module.location, "title": last_category + " / " + title})
return dict(map(get_entry, _get_discussion_modules(course))) return dict(map(get_entry, get_accessible_discussion_modules(course, user)))
def _filter_unstarted_categories(category_map): def _filter_unstarted_categories(category_map):
...@@ -138,14 +145,14 @@ def _sort_map_entries(category_map, sort_alpha): ...@@ -138,14 +145,14 @@ def _sort_map_entries(category_map, sort_alpha):
category_map["children"] = [x[0] for x in sorted(things, key=lambda x: x[1]["sort_key"])] category_map["children"] = [x[0] for x in sorted(things, key=lambda x: x[1]["sort_key"])]
def get_discussion_category_map(course): def get_discussion_category_map(course, user):
""" """
Transform the list of this course's discussion modules into a recursive dictionary structure. This is used Transform the list of this course's discussion modules into a recursive dictionary structure. This is used
to render the discussion category map in the discussion tab sidebar. to render the discussion category map in the discussion tab sidebar for a given user.
""" """
unexpanded_category_map = defaultdict(list) unexpanded_category_map = defaultdict(list)
modules = _get_discussion_modules(course) modules = get_accessible_discussion_modules(course, user)
is_course_cohorted = course.is_cohorted is_course_cohorted = course.is_cohorted
cohorted_discussion_ids = course.cohorted_discussions cohorted_discussion_ids = course.cohorted_discussions
...@@ -218,21 +225,15 @@ def get_discussion_category_map(course): ...@@ -218,21 +225,15 @@ def get_discussion_category_map(course):
return _filter_unstarted_categories(category_map) return _filter_unstarted_categories(category_map)
def get_discussion_categories_ids(course): def get_discussion_categories_ids(course, user):
""" """
Returns a list of available ids of categories for the course. Returns a list of available ids of categories for the course that
are accessible to the given user.
""" """
ids = [] accessible_discussion_ids = [
queue = [get_discussion_category_map(course)] module.discussion_id for module in get_accessible_discussion_modules(course, user)
while queue: ]
category_map = queue.pop() return course.top_level_discussion_topic_ids + accessible_discussion_ids
for child in category_map["children"]:
if child in category_map["entries"]:
ids.append(category_map["entries"][child]["id"])
else:
queue.append(category_map["subcategories"][child])
return ids
class JsonResponse(HttpResponse): class JsonResponse(HttpResponse):
...@@ -382,12 +383,12 @@ def extend_content(content): ...@@ -382,12 +383,12 @@ def extend_content(content):
return merge_dict(content, content_info) return merge_dict(content, content_info)
def add_courseware_context(content_list, course, id_map=None): def add_courseware_context(content_list, course, user, id_map=None):
""" """
Decorates `content_list` with courseware metadata. Decorates `content_list` with courseware metadata.
""" """
if id_map is None: if id_map is None:
id_map = get_discussion_id_map(course) id_map = get_discussion_id_map(course, user)
for content in content_list: for content in content_list:
commentable_id = content['commentable_id'] commentable_id = content['commentable_id']
......
...@@ -1560,6 +1560,8 @@ INSTALLED_APPS = ( ...@@ -1560,6 +1560,8 @@ INSTALLED_APPS = (
'provider.oauth2', 'provider.oauth2',
'oauth2_provider', 'oauth2_provider',
'oauth_exchange',
# For the wiki # For the wiki
'wiki', # The new django-wiki from benjaoming 'wiki', # The new django-wiki from benjaoming
'django_notify', 'django_notify',
......
...@@ -5,6 +5,7 @@ from django.conf.urls.static import static ...@@ -5,6 +5,7 @@ from django.conf.urls.static import static
import django.contrib.auth.views import django.contrib.auth.views
from microsite_configuration import microsite from microsite_configuration import microsite
import oauth_exchange.views
# Uncomment the next two lines to enable the admin: # Uncomment the next two lines to enable the admin:
if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'): if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'):
...@@ -588,6 +589,11 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'): ...@@ -588,6 +589,11 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'):
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
urlpatterns += ( urlpatterns += (
url(r'', include('third_party_auth.urls')), url(r'', include('third_party_auth.urls')),
url(
r'^oauth2/exchange_access_token/(?P<backend>[^/]+)/$',
oauth_exchange.views.AccessTokenExchangeView.as_view(),
name="exchange_access_token"
),
url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'), url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'),
) )
......
...@@ -2,14 +2,13 @@ import logging ...@@ -2,14 +2,13 @@ import logging
from optparse import make_option from optparse import make_option
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from openedx.core.djangoapps.content.course_structures.models import update_course_structure from openedx.core.djangoapps.content.course_structures.tasks import update_course_structure
logger = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Command(BaseCommand): class Command(BaseCommand):
...@@ -31,16 +30,21 @@ class Command(BaseCommand): ...@@ -31,16 +30,21 @@ class Command(BaseCommand):
course_keys = [CourseKey.from_string(arg) for arg in args] course_keys = [CourseKey.from_string(arg) for arg in args]
if not course_keys: if not course_keys:
logger.fatal('No courses specified.') log.fatal('No courses specified.')
return return
logger.info('Generating course structures for %d courses.', len(course_keys)) log.info('Generating course structures for %d courses.', len(course_keys))
logging.debug('Generating course structure(s) for the following courses: %s', course_keys) log.debug('Generating course structure(s) for the following courses: %s', course_keys)
for course_key in course_keys: for course_key in course_keys:
try: try:
update_course_structure(unicode(course_key)) # Run the update task synchronously so that we know when all course structures have been updated.
except Exception as e: # TODO Future improvement: Use .delay(), add return value to ResultSet, and wait for execution of
logger.error('An error occurred while generating course structure for %s: %s', unicode(course_key), e) # all tasks using ResultSet.join(). I (clintonb) am opting not to make this improvement right now
# as I do not have time to test it fully.
logger.info('Finished generating course structures.') update_course_structure.apply(unicode(course_key))
except Exception as ex:
log.exception('An error occurred while generating course structure for %s: %s',
unicode(course_key), ex.message)
log.info('Finished generating course structures.')
import json import json
import logging import logging
from celery.task import task
from django.dispatch import receiver
from model_utils.models import TimeStampedModel from model_utils.models import TimeStampedModel
from opaque_keys.edx.keys import CourseKey
from xmodule.modulestore.django import modulestore, SignalHandler
from util.models import CompressedTextField from util.models import CompressedTextField
from xmodule_django.models import CourseKeyField from xmodule_django.models import CourseKeyField
...@@ -30,66 +26,6 @@ class CourseStructure(TimeStampedModel): ...@@ -30,66 +26,6 @@ class CourseStructure(TimeStampedModel):
return json.loads(self.structure_json) return json.loads(self.structure_json)
return None return None
# Signals must be imported in a file that is automatically loaded at app startup (e.g. models.py). We import them
def generate_course_structure(course_key): # at the end of this file to avoid circular dependencies.
""" import signals # pylint: disable=unused-import
Generates a course structure dictionary for the specified course.
"""
course = modulestore().get_course(course_key, depth=None)
blocks_stack = [course]
blocks_dict = {}
while blocks_stack:
curr_block = blocks_stack.pop()
children = curr_block.get_children() if curr_block.has_children else []
blocks_dict[unicode(curr_block.scope_ids.usage_id)] = {
"usage_key": unicode(curr_block.scope_ids.usage_id),
"block_type": curr_block.category,
"display_name": curr_block.display_name,
"graded": curr_block.graded,
"format": curr_block.format,
"children": [unicode(child.scope_ids.usage_id) for child in children]
}
blocks_stack.extend(children)
return {
"root": unicode(course.scope_ids.usage_id),
"blocks": blocks_dict
}
@receiver(SignalHandler.course_published)
def listen_for_course_publish(sender, course_key, **kwargs):
# Note: The countdown=0 kwarg is set to to ensure the method below does not attempt to access the course
# before the signal emitter has finished all operations. This is also necessary to ensure all tests pass.
update_course_structure.delay(unicode(course_key), countdown=0)
@task(name=u'openedx.core.djangoapps.content.course_structures.models.update_course_structure')
def update_course_structure(course_key):
"""
Regenerates and updates the course structure (in the database) for the specified course.
"""
# Ideally we'd like to accept a CourseLocator; however, CourseLocator is not JSON-serializable (by default) so
# Celery's delayed tasks fail to start. For this reason, callers should pass the course key as a Unicode string.
if not isinstance(course_key, basestring):
raise ValueError('course_key must be a string. {} is not acceptable.'.format(type(course_key)))
course_key = CourseKey.from_string(course_key)
try:
structure = generate_course_structure(course_key)
except Exception as e:
logger.error('An error occurred while generating course structure: %s', e)
raise
structure_json = json.dumps(structure)
cs, created = CourseStructure.objects.get_or_create(
course_id=course_key,
defaults={'structure_json': structure_json}
)
if not created:
cs.structure_json = structure_json
cs.save()
return cs
from django.dispatch.dispatcher import receiver
from xmodule.modulestore.django import SignalHandler
@receiver(SignalHandler.course_published)
def listen_for_course_publish(sender, course_key, **kwargs): # pylint: disable=unused-argument
# Import tasks here to avoid a circular import.
from .tasks import update_course_structure
# Note: The countdown=0 kwarg is set to to ensure the method below does not attempt to access the course
# before the signal emitter has finished all operations. This is also necessary to ensure all tests pass.
update_course_structure.delay(unicode(course_key), countdown=0)
import json
import logging
from celery.task import task
from opaque_keys.edx.keys import CourseKey
from xmodule.modulestore.django import modulestore
log = logging.getLogger('edx.celery.task')
def _generate_course_structure(course_key):
"""
Generates a course structure dictionary for the specified course.
"""
course = modulestore().get_course(course_key, depth=None)
blocks_stack = [course]
blocks_dict = {}
while blocks_stack:
curr_block = blocks_stack.pop()
children = curr_block.get_children() if curr_block.has_children else []
key = unicode(curr_block.scope_ids.usage_id)
block = {
"usage_key": key,
"block_type": curr_block.category,
"display_name": curr_block.display_name,
"children": [unicode(child.scope_ids.usage_id) for child in children]
}
# Retrieve these attributes separately so that we can fail gracefully if the block doesn't have the attribute.
attrs = (('graded', False), ('format', None))
for attr, default in attrs:
if hasattr(curr_block, attr):
block[attr] = getattr(curr_block, attr, default)
else:
log.warning('Failed to retrieve %s attribute of block %s. Defaulting to %s.', attr, key, default)
block[attr] = default
blocks_dict[key] = block
# Add this blocks children to the stack so that we can traverse them as well.
blocks_stack.extend(children)
return {
"root": unicode(course.scope_ids.usage_id),
"blocks": blocks_dict
}
@task(name=u'openedx.core.djangoapps.content.course_structures.tasks.update_course_structure')
def update_course_structure(course_key):
"""
Regenerates and updates the course structure (in the database) for the specified course.
"""
# Import here to avoid circular import.
from .models import CourseStructure
# Ideally we'd like to accept a CourseLocator; however, CourseLocator is not JSON-serializable (by default) so
# Celery's delayed tasks fail to start. For this reason, callers should pass the course key as a Unicode string.
if not isinstance(course_key, basestring):
raise ValueError('course_key must be a string. {} is not acceptable.'.format(type(course_key)))
course_key = CourseKey.from_string(course_key)
try:
structure = _generate_course_structure(course_key)
except Exception as ex:
log.exception('An error occurred while generating course structure: %s', ex.message)
raise
structure_json = json.dumps(structure)
cs, created = CourseStructure.objects.get_or_create(
course_id=course_key,
defaults={'structure_json': structure_json}
)
if not created:
cs.structure_json = structure_json
cs.save()
import json import json
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
from openedx.core.djangoapps.content.course_structures.models import generate_course_structure, CourseStructure from openedx.core.djangoapps.content.course_structures.models import CourseStructure
from openedx.core.djangoapps.content.course_structures.tasks import _generate_course_structure, update_course_structure
class CourseStructureTests(ModuleStoreTestCase): class CourseStructureTaskTests(ModuleStoreTestCase):
def setUp(self, **kwargs): def setUp(self, **kwargs):
super(CourseStructureTests, self).setUp() super(CourseStructureTaskTests, self).setUp()
self.course = CourseFactory.create() self.course = CourseFactory.create()
self.section = ItemFactory.create(parent=self.course, category='chapter', display_name='Test Section') self.section = ItemFactory.create(parent=self.course, category='chapter', display_name='Test Section')
CourseStructure.objects.all().delete() CourseStructure.objects.all().delete()
...@@ -38,7 +40,7 @@ class CourseStructureTests(ModuleStoreTestCase): ...@@ -38,7 +40,7 @@ class CourseStructureTests(ModuleStoreTestCase):
} }
self.maxDiff = None self.maxDiff = None
actual = generate_course_structure(self.course.id) actual = _generate_course_structure(self.course.id)
self.assertDictEqual(actual, expected) self.assertDictEqual(actual, expected)
def test_structure_json(self): def test_structure_json(self):
...@@ -77,3 +79,41 @@ class CourseStructureTests(ModuleStoreTestCase): ...@@ -77,3 +79,41 @@ class CourseStructureTests(ModuleStoreTestCase):
structure_json = json.dumps(structure) structure_json = json.dumps(structure)
cs = CourseStructure.objects.create(course_id=self.course.id, structure_json=structure_json) cs = CourseStructure.objects.create(course_id=self.course.id, structure_json=structure_json)
self.assertDictEqual(cs.structure, structure) self.assertDictEqual(cs.structure, structure)
def test_block_with_missing_fields(self):
"""
The generator should continue to operate on blocks/XModule that do not have graded or format fields.
"""
# TODO In the future, test logging using testfixtures.LogCapture
# (https://pythonhosted.org/testfixtures/logging.html). Talk to TestEng before adding that library.
category = 'peergrading'
display_name = 'Testing Module'
module = ItemFactory.create(parent=self.section, category=category, display_name=display_name)
structure = _generate_course_structure(self.course.id)
usage_key = unicode(module.location)
actual = structure['blocks'][usage_key]
expected = {
"usage_key": usage_key,
"block_type": category,
"display_name": display_name,
"graded": False,
"format": None,
"children": []
}
self.assertEqual(actual, expected)
def test_update_course_structure(self):
"""
Test the actual task that orchestrates data generation and updating the database.
"""
# Method requires string input
course_id = self.course.id
self.assertRaises(ValueError, update_course_structure, course_id)
# Ensure a CourseStructure object is created
structure = _generate_course_structure(course_id)
update_course_structure(unicode(course_id))
cs = CourseStructure.objects.get(course_id=course_id)
self.assertEqual(cs.course_id, course_id)
self.assertEqual(cs.structure, structure)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment