Commit d4817685 by Renzo Lucioni

Merge pull request #11691 from edx/jeskew/shared_modulestore_test_case_conversion

WIP: Lots of Python unittest speedups
parents be7be407 b02d1c12
"""
Script for converting a tar.gz file representing an exported course
to the archive format used by a different version of export.
Sample invocation: ./manage.py export_convert_format mycourse.tar.gz ~/newformat/
"""
import os
from path import Path as path
from django.core.management.base import BaseCommand, CommandError
from django.conf import settings
from tempfile import mkdtemp
import tarfile
import shutil
from openedx.core.lib.extract_tar import safetar_extractall
from xmodule.modulestore.xml_exporter import convert_between_versions
class Command(BaseCommand):
"""
Convert between export formats.
"""
help = 'Convert between versions 0 and 1 of the course export format'
args = '<tar.gz archive file> <output path>'
def handle(self, *args, **options):
"Execute the command"
if len(args) != 2:
raise CommandError("export requires two arguments: <tar.gz file> <output path>")
source_archive = args[0]
output_path = args[1]
# Create temp directories to extract the source and create the target archive.
temp_source_dir = mkdtemp(dir=settings.DATA_DIR)
temp_target_dir = mkdtemp(dir=settings.DATA_DIR)
try:
extract_source(source_archive, temp_source_dir)
desired_version = convert_between_versions(temp_source_dir, temp_target_dir)
# New zip up the target directory.
parts = os.path.basename(source_archive).split('.')
archive_name = path(output_path) / "{source_name}_version_{desired_version}.tar.gz".format(
source_name=parts[0], desired_version=desired_version
)
with open(archive_name, "w"):
tar_file = tarfile.open(archive_name, mode='w:gz')
try:
for item in os.listdir(temp_target_dir):
tar_file.add(path(temp_target_dir) / item, arcname=item)
finally:
tar_file.close()
print "Created archive {0}".format(archive_name)
except ValueError as err:
raise CommandError(err)
finally:
shutil.rmtree(temp_source_dir)
shutil.rmtree(temp_target_dir)
def extract_source(source_archive, target):
"""
Extract the archive into the given target directory.
"""
with tarfile.open(source_archive) as tar_file:
safetar_extractall(tar_file, target)
"""
Test for export_convert_format.
"""
from unittest import TestCase
from django.core.management import call_command, CommandError
from django.conf import settings
from tempfile import mkdtemp
import shutil
from path import Path as path
from contentstore.management.commands.export_convert_format import Command, extract_source
from xmodule.tests.helpers import directories_equal
class ConvertExportFormat(TestCase):
"""
Tests converting between export formats.
"""
def setUp(self):
""" Common setup. """
super(ConvertExportFormat, self).setUp()
self.temp_dir = mkdtemp(dir=settings.DATA_DIR)
self.addCleanup(shutil.rmtree, self.temp_dir)
self.data_dir = path(__file__).realpath().parent / 'data'
self.version0 = self.data_dir / "Version0_drafts.tar.gz"
self.version1 = self.data_dir / "Version1_drafts.tar.gz"
self.command = Command()
def test_no_args(self):
""" Test error condition of no arguments. """
errstring = "export requires two arguments"
with self.assertRaisesRegexp(CommandError, errstring):
self.command.handle()
def test_version1_archive(self):
"""
Smoke test for creating a version 1 archive from a version 0.
"""
call_command('export_convert_format', self.version0, self.temp_dir)
output = path(self.temp_dir) / 'Version0_drafts_version_1.tar.gz'
self.assertTrue(self._verify_archive_equality(output, self.version1))
def test_version0_archive(self):
"""
Smoke test for creating a version 0 archive from a version 1.
"""
call_command('export_convert_format', self.version1, self.temp_dir)
output = path(self.temp_dir) / 'Version1_drafts_version_0.tar.gz'
self.assertTrue(self._verify_archive_equality(output, self.version0))
def _verify_archive_equality(self, file1, file2):
"""
Helper function for determining if 2 archives are equal.
"""
temp_dir_1 = mkdtemp(dir=settings.DATA_DIR)
temp_dir_2 = mkdtemp(dir=settings.DATA_DIR)
try:
extract_source(file1, temp_dir_1)
extract_source(file2, temp_dir_2)
return directories_equal(temp_dir_1, temp_dir_2)
finally:
shutil.rmtree(temp_dir_1)
shutil.rmtree(temp_dir_2)
...@@ -29,8 +29,9 @@ from opaque_keys.edx.locations import CourseLocator ...@@ -29,8 +29,9 @@ from opaque_keys.edx.locations import CourseLocator
from xmodule.error_module import ErrorDescriptor from xmodule.error_module import ErrorDescriptor
from course_action_state.models import CourseRerunState from course_action_state.models import CourseRerunState
TOTAL_COURSES_COUNT = 500
USER_COURSES_COUNT = 50 TOTAL_COURSES_COUNT = 10
USER_COURSES_COUNT = 1
@ddt.ddt @ddt.ddt
...@@ -157,8 +158,8 @@ class TestCourseListing(ModuleStoreTestCase, XssTestMixin): ...@@ -157,8 +158,8 @@ class TestCourseListing(ModuleStoreTestCase, XssTestMixin):
self.assertEqual(courses_list_by_groups, []) self.assertEqual(courses_list_by_groups, [])
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.split, 5), (ModuleStoreEnum.Type.split, 3),
(ModuleStoreEnum.Type.mongo, 3) (ModuleStoreEnum.Type.mongo, 2)
) )
@ddt.unpack @ddt.unpack
def test_staff_course_listing(self, default_store, mongo_calls): def test_staff_course_listing(self, default_store, mongo_calls):
...@@ -265,8 +266,8 @@ class TestCourseListing(ModuleStoreTestCase, XssTestMixin): ...@@ -265,8 +266,8 @@ class TestCourseListing(ModuleStoreTestCase, XssTestMixin):
) )
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.split, 150, 505), (ModuleStoreEnum.Type.split, 3, 13),
(ModuleStoreEnum.Type.mongo, USER_COURSES_COUNT, 3) (ModuleStoreEnum.Type.mongo, USER_COURSES_COUNT, 2)
) )
@ddt.unpack @ddt.unpack
def test_course_listing_performance(self, store, courses_list_from_group_calls, courses_list_calls): def test_course_listing_performance(self, store, courses_list_from_group_calls, courses_list_calls):
......
...@@ -9,7 +9,7 @@ settings.INSTALLED_APPS # pylint: disable=pointless-statement ...@@ -9,7 +9,7 @@ settings.INSTALLED_APPS # pylint: disable=pointless-statement
from openedx.core.lib.django_startup import autostartup from openedx.core.lib.django_startup import autostartup
import django import django
from monkey_patch import third_party_auth from monkey_patch import third_party_auth, django_db_models_options
import xmodule.x_module import xmodule.x_module
import cms.lib.xblock.runtime import cms.lib.xblock.runtime
...@@ -22,6 +22,7 @@ def run(): ...@@ -22,6 +22,7 @@ def run():
Executed during django startup Executed during django startup
""" """
third_party_auth.patch() third_party_auth.patch()
django_db_models_options.patch()
# Comprehensive theming needs to be set up before django startup, # Comprehensive theming needs to be set up before django startup,
# because modifying django template paths after startup has no effect. # because modifying django template paths after startup has no effect.
......
"""
Monkey patch implementation of the following _expire_cache performance improvement:
https://github.com/django/django/commit/7628f87e2b1ab4b8a881f06c8973be4c368aaa3d
Remove once we upgrade to a version of django which includes this fix natively!
NOTE: This is on django's master branch but is NOT currently part of any django 1.8 or 1.9 release.
"""
from django.db.models.options import Options
def patch():
"""
Monkey-patch the Options class.
"""
def _expire_cache(self, forward=True, reverse=True):
# pylint: disable=missing-docstring
# This method is usually called by apps.cache_clear(), when the
# registry is finalized, or when a new field is added.
if forward:
for cache_key in self.FORWARD_PROPERTIES:
if cache_key in self.__dict__:
delattr(self, cache_key)
if reverse and not self.abstract:
for cache_key in self.REVERSE_PROPERTIES:
if cache_key in self.__dict__:
delattr(self, cache_key)
self._get_fields_cache = {} # pylint: disable=protected-access
# Patch constants as a set instead of a list.
Options.FORWARD_PROPERTIES = {'fields', 'many_to_many', 'concrete_fields',
'local_concrete_fields', '_forward_fields_map'}
Options.REVERSE_PROPERTIES = {'related_objects', 'fields_map', '_relation_tree'}
# Patch the expire_cache method to utilize constant's new set data structure.
Options._expire_cache = _expire_cache # pylint: disable=protected-access
...@@ -4,6 +4,7 @@ Modulestore configuration for test cases. ...@@ -4,6 +4,7 @@ Modulestore configuration for test cases.
""" """
import functools import functools
from uuid import uuid4 from uuid import uuid4
from contextlib import contextmanager
from mock import patch from mock import patch
...@@ -269,9 +270,10 @@ class SharedModuleStoreTestCase(TestCase): ...@@ -269,9 +270,10 @@ class SharedModuleStoreTestCase(TestCase):
multi_db = True multi_db = True
@classmethod @classmethod
def setUpClass(cls): def _setUpModuleStore(cls): # pylint: disable=invalid-name
super(SharedModuleStoreTestCase, cls).setUpClass() """
Set up the modulestore for an entire test class.
"""
cls._settings_override = override_settings(MODULESTORE=cls.MODULESTORE) cls._settings_override = override_settings(MODULESTORE=cls.MODULESTORE)
cls._settings_override.__enter__() cls._settings_override.__enter__()
XMODULE_FACTORY_LOCK.enable() XMODULE_FACTORY_LOCK.enable()
...@@ -279,6 +281,40 @@ class SharedModuleStoreTestCase(TestCase): ...@@ -279,6 +281,40 @@ class SharedModuleStoreTestCase(TestCase):
cls.store = modulestore() cls.store = modulestore()
@classmethod @classmethod
@contextmanager
def setUpClassAndTestData(cls): # pylint: disable=invalid-name
"""
For use when the test class has a setUpTestData() method that uses variables
that are setup during setUpClass() of the same test class.
Use it like so:
@classmethod
def setUpClass(cls):
with super(MyTestClass, cls).setUpClassAndTestData():
<all the cls.setUpClass() setup code that performs modulestore setup...>
@classmethod
def setUpTestData(cls):
<all the setup code that creates Django models per test class...>
<these models can use variables (courses) setup in setUpClass() above>
"""
cls._setUpModuleStore()
# Now yield to allow the test class to run its setUpClass() setup code.
yield
# Now call the base class, which calls back into the test class's setUpTestData().
super(SharedModuleStoreTestCase, cls).setUpClass()
@classmethod
def setUpClass(cls):
"""
For use when the test class has no setUpTestData() method -or-
when that method does not use variable set up in setUpClass().
"""
super(SharedModuleStoreTestCase, cls).setUpClass()
cls._setUpModuleStore()
@classmethod
def tearDownClass(cls): def tearDownClass(cls):
drop_mongo_collections() # pylint: disable=no-value-for-parameter drop_mongo_collections() # pylint: disable=no-value-for-parameter
clear_all_caches() clear_all_caches()
......
...@@ -24,8 +24,6 @@ from opaque_keys.edx.locator import CourseLocator, LibraryLocator ...@@ -24,8 +24,6 @@ from opaque_keys.edx.locator import CourseLocator, LibraryLocator
DRAFT_DIR = "drafts" DRAFT_DIR = "drafts"
PUBLISHED_DIR = "published" PUBLISHED_DIR = "published"
EXPORT_VERSION_FILE = "format.json"
EXPORT_VERSION_KEY = "export_format"
DEFAULT_CONTENT_FIELDS = ['metadata', 'data'] DEFAULT_CONTENT_FIELDS = ['metadata', 'data']
...@@ -408,90 +406,3 @@ def export_extra_content(export_fs, modulestore, source_course_key, dest_course_ ...@@ -408,90 +406,3 @@ def export_extra_content(export_fs, modulestore, source_course_key, dest_course_
# export content fields other then metadata and data in json format in current directory # export content fields other then metadata and data in json format in current directory
_export_field_content(item, item_dir) _export_field_content(item, item_dir)
def convert_between_versions(source_dir, target_dir):
"""
Converts a version 0 export format to version 1, and vice versa.
@param source_dir: the directory structure with the course export that should be converted.
The contents of source_dir will not be altered.
@param target_dir: the directory where the converted export should be written.
@return: the version number of the converted export.
"""
def convert_to_version_1():
""" Convert a version 0 archive to version 0 """
os.mkdir(copy_root)
with open(copy_root / EXPORT_VERSION_FILE, 'w') as f:
f.write('{{"{export_key}": 1}}\n'.format(export_key=EXPORT_VERSION_KEY))
# If a drafts folder exists, copy it over.
copy_drafts()
# Now copy everything into the published directory
published_dir = copy_root / PUBLISHED_DIR
shutil.copytree(path(source_dir) / course_name, published_dir)
# And delete the nested drafts directory, if it exists.
nested_drafts_dir = published_dir / DRAFT_DIR
if nested_drafts_dir.isdir():
shutil.rmtree(nested_drafts_dir)
def convert_to_version_0():
""" Convert a version 1 archive to version 0 """
# Copy everything in "published" up to the top level.
published_dir = path(source_dir) / course_name / PUBLISHED_DIR
if not published_dir.isdir():
raise ValueError("a version 1 archive must contain a published branch")
shutil.copytree(published_dir, copy_root)
# If there is a DRAFT branch, copy it. All other branches are ignored.
copy_drafts()
def copy_drafts():
"""
Copy drafts directory from the old archive structure to the new.
"""
draft_dir = path(source_dir) / course_name / DRAFT_DIR
if draft_dir.isdir():
shutil.copytree(draft_dir, copy_root / DRAFT_DIR)
root = os.listdir(source_dir)
if len(root) != 1 or (path(source_dir) / root[0]).isfile():
raise ValueError("source archive does not have single course directory at top level")
course_name = root[0]
# For this version of the script, we simply convert back and forth between version 0 and 1.
original_version = get_version(path(source_dir) / course_name)
if original_version not in [0, 1]:
raise ValueError("unknown version: " + str(original_version))
desired_version = 1 if original_version is 0 else 0
copy_root = path(target_dir) / course_name
if desired_version == 1:
convert_to_version_1()
else:
convert_to_version_0()
return desired_version
def get_version(course_path):
"""
Return the export format version number for the given
archive directory structure (represented as a path instance).
If the archived file does not correspond to a known export
format, None will be returned.
"""
format_file = course_path / EXPORT_VERSION_FILE
if not format_file.isfile():
return 0
with open(format_file, "r") as f:
data = json.load(f)
if EXPORT_VERSION_KEY in data:
return data[EXPORT_VERSION_KEY]
return None
...@@ -25,9 +25,6 @@ from xblock.test.tools import blocks_are_equivalent ...@@ -25,9 +25,6 @@ from xblock.test.tools import blocks_are_equivalent
from opaque_keys.edx.locations import Location from opaque_keys.edx.locations import Location
from xmodule.modulestore import EdxJSONEncoder from xmodule.modulestore import EdxJSONEncoder
from xmodule.modulestore.xml import XMLModuleStore from xmodule.modulestore.xml import XMLModuleStore
from xmodule.modulestore.xml_exporter import (
convert_between_versions, get_version
)
from xmodule.tests import DATA_DIR from xmodule.tests import DATA_DIR
from xmodule.tests.helpers import directories_equal from xmodule.tests.helpers import directories_equal
from xmodule.x_module import XModuleMixin from xmodule.x_module import XModuleMixin
...@@ -214,173 +211,3 @@ class TestEdxJsonEncoder(unittest.TestCase): ...@@ -214,173 +211,3 @@ class TestEdxJsonEncoder(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.encoder.default({}) self.encoder.default({})
class ConvertExportFormat(unittest.TestCase):
"""
Tests converting between export formats.
"""
def setUp(self):
""" Common setup. """
super(ConvertExportFormat, self).setUp()
# Directory for expanding all the test archives
self.temp_dir = mkdtemp()
self.addCleanup(shutil.rmtree, self.temp_dir)
# Directory where new archive will be created
self.result_dir = path(self.temp_dir) / uuid.uuid4().hex
os.mkdir(self.result_dir)
# Expand all the test archives and store their paths.
self.data_dir = path(__file__).realpath().parent / 'data'
self._version0_nodrafts = None
self._version1_nodrafts = None
self._version0_drafts = None
self._version1_drafts = None
self._version1_drafts_extra_branch = None
self._no_version = None
@property
def version0_nodrafts(self):
"lazily expand this"
if self._version0_nodrafts is None:
self._version0_nodrafts = self._expand_archive('Version0_nodrafts.tar.gz')
return self._version0_nodrafts
@property
def version1_nodrafts(self):
"lazily expand this"
if self._version1_nodrafts is None:
self._version1_nodrafts = self._expand_archive('Version1_nodrafts.tar.gz')
return self._version1_nodrafts
@property
def version0_drafts(self):
"lazily expand this"
if self._version0_drafts is None:
self._version0_drafts = self._expand_archive('Version0_drafts.tar.gz')
return self._version0_drafts
@property
def version1_drafts(self):
"lazily expand this"
if self._version1_drafts is None:
self._version1_drafts = self._expand_archive('Version1_drafts.tar.gz')
return self._version1_drafts
@property
def version1_drafts_extra_branch(self):
"lazily expand this"
if self._version1_drafts_extra_branch is None:
self._version1_drafts_extra_branch = self._expand_archive('Version1_drafts_extra_branch.tar.gz')
return self._version1_drafts_extra_branch
@property
def no_version(self):
"lazily expand this"
if self._no_version is None:
self._no_version = self._expand_archive('NoVersionNumber.tar.gz')
return self._no_version
def _expand_archive(self, name):
""" Expand archive into a directory and return the directory. """
target = path(self.temp_dir) / uuid.uuid4().hex
os.mkdir(target)
with tarfile.open(self.data_dir / name) as tar_file:
tar_file.extractall(path=target)
return target
def test_no_version(self):
""" Test error condition of no version number specified. """
errstring = "unknown version"
with self.assertRaisesRegexp(ValueError, errstring):
convert_between_versions(self.no_version, self.result_dir)
def test_no_published(self):
""" Test error condition of a version 1 archive with no published branch. """
errstring = "version 1 archive must contain a published branch"
no_published = self._expand_archive('Version1_nopublished.tar.gz')
with self.assertRaisesRegexp(ValueError, errstring):
convert_between_versions(no_published, self.result_dir)
def test_empty_course(self):
""" Test error condition of a version 1 archive with no published branch. """
errstring = "source archive does not have single course directory at top level"
empty_course = self._expand_archive('EmptyCourse.tar.gz')
with self.assertRaisesRegexp(ValueError, errstring):
convert_between_versions(empty_course, self.result_dir)
def test_convert_to_1_nodrafts(self):
"""
Test for converting from version 0 of export format to version 1 in a course with no drafts.
"""
self._verify_conversion(self.version0_nodrafts, self.version1_nodrafts)
def test_convert_to_1_drafts(self):
"""
Test for converting from version 0 of export format to version 1 in a course with drafts.
"""
self._verify_conversion(self.version0_drafts, self.version1_drafts)
def test_convert_to_0_nodrafts(self):
"""
Test for converting from version 1 of export format to version 0 in a course with no drafts.
"""
self._verify_conversion(self.version1_nodrafts, self.version0_nodrafts)
def test_convert_to_0_drafts(self):
"""
Test for converting from version 1 of export format to version 0 in a course with drafts.
"""
self._verify_conversion(self.version1_drafts, self.version0_drafts)
def test_convert_to_0_extra_branch(self):
"""
Test for converting from version 1 of export format to version 0 in a course
with drafts and an extra branch.
"""
self._verify_conversion(self.version1_drafts_extra_branch, self.version0_drafts)
def test_equality_function(self):
"""
Check equality function returns False for unequal directories.
"""
self.assertFalse(directories_equal(self.version1_nodrafts, self.version0_nodrafts))
self.assertFalse(directories_equal(self.version1_drafts_extra_branch, self.version1_drafts))
def test_version_0(self):
"""
Check that get_version correctly identifies a version 0 archive (old format).
"""
self.assertEqual(0, self._version_test(self.version0_nodrafts))
def test_version_1(self):
"""
Check that get_version correctly identifies a version 1 archive (new format).
"""
self.assertEqual(1, self._version_test(self.version1_nodrafts))
def test_version_missing(self):
"""
Check that get_version returns None if no version number is specified,
and the archive is not version 0.
"""
self.assertIsNone(self._version_test(self.no_version))
def _version_test(self, archive_dir):
"""
Helper function for version tests.
"""
root = os.listdir(archive_dir)
course_directory = archive_dir / root[0]
return get_version(course_directory)
def _verify_conversion(self, source_archive, comparison_archive):
"""
Helper function for conversion tests.
"""
convert_between_versions(source_archive, self.result_dir)
self.assertTrue(directories_equal(self.result_dir, comparison_archive))
...@@ -25,16 +25,12 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase): ...@@ -25,16 +25,12 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(TestCCXModulestoreWrapper, cls).setUpClass() super(TestCCXModulestoreWrapper, cls).setUpClass()
cls.course = course = CourseFactory.create() cls.course = CourseFactory.create()
cls.mooc_start = start = datetime.datetime( start = datetime.datetime(2010, 5, 12, 2, 42, tzinfo=pytz.UTC)
2010, 5, 12, 2, 42, tzinfo=pytz.UTC due = datetime.datetime(2010, 7, 7, 0, 0, tzinfo=pytz.UTC)
)
cls.mooc_due = due = datetime.datetime(
2010, 7, 7, 0, 0, tzinfo=pytz.UTC
)
# Create a course outline # Create a course outline
cls.chapters = chapters = [ cls.chapters = chapters = [
ItemFactory.create(start=start, parent=course) for _ in xrange(2) ItemFactory.create(start=start, parent=cls.course) for _ in xrange(2)
] ]
cls.sequentials = sequentials = [ cls.sequentials = sequentials = [
ItemFactory.create(parent=c) for _ in xrange(2) for c in chapters ItemFactory.create(parent=c) for _ in xrange(2) for c in chapters
...@@ -48,20 +44,24 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase): ...@@ -48,20 +44,24 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase):
ItemFactory.create(parent=v, category='html') for _ in xrange(2) for v in verticals ItemFactory.create(parent=v, category='html') for _ in xrange(2) for v in verticals
] ]
@classmethod
def setUpTestData(cls):
"""
Set up models for the whole TestCase.
"""
cls.user = UserFactory.create()
# Create instructor account
cls.coach = AdminFactory.create()
def setUp(self): def setUp(self):
""" """
Set up tests Set up tests
""" """
super(TestCCXModulestoreWrapper, self).setUp() super(TestCCXModulestoreWrapper, self).setUp()
self.user = UserFactory.create()
# Create instructor account
coach = AdminFactory.create()
self.ccx = ccx = CustomCourseForEdX( self.ccx = ccx = CustomCourseForEdX(
course_id=self.course.id, course_id=self.course.id,
display_name='Test CCX', display_name='Test CCX',
coach=coach coach=self.coach
) )
ccx.save() ccx.save()
...@@ -132,12 +132,13 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase): ...@@ -132,12 +132,13 @@ class TestCCXModulestoreWrapper(SharedModuleStoreTestCase):
def test_publication_api(self): def test_publication_api(self):
"""verify that we can correctly discern a published item by ccx key""" """verify that we can correctly discern a published item by ccx key"""
for expected in self.blocks: with self.store.bulk_operations(self.ccx_locator):
block_key = self.ccx_locator.make_usage_key( for expected in self.blocks:
expected.location.block_type, expected.location.block_id block_key = self.ccx_locator.make_usage_key(
) expected.location.block_type, expected.location.block_id
self.assertTrue(self.store.has_published_version(expected)) )
self.store.unpublish(block_key, self.user.id) self.assertTrue(self.store.has_published_version(expected))
self.assertFalse(self.store.has_published_version(expected)) self.store.unpublish(block_key, self.user.id)
self.store.publish(block_key, self.user.id) self.assertFalse(self.store.has_published_version(expected))
self.assertTrue(self.store.has_published_version(expected)) self.store.publish(block_key, self.user.id)
self.assertTrue(self.store.has_published_version(expected))
...@@ -13,7 +13,7 @@ from lms.djangoapps.courseware.tests.test_field_overrides import inject_field_ov ...@@ -13,7 +13,7 @@ from lms.djangoapps.courseware.tests.test_field_overrides import inject_field_ov
from request_cache.middleware import RequestCache from request_cache.middleware import RequestCache
from student.tests.factories import AdminFactory from student.tests.factories import AdminFactory
from xmodule.modulestore.tests.django_utils import ( from xmodule.modulestore.tests.django_utils import (
ModuleStoreTestCase, SharedModuleStoreTestCase,
TEST_DATA_SPLIT_MODULESTORE) TEST_DATA_SPLIT_MODULESTORE)
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
...@@ -26,26 +26,25 @@ from lms.djangoapps.ccx.tests.utils import flatten, iter_blocks ...@@ -26,26 +26,25 @@ from lms.djangoapps.ccx.tests.utils import flatten, iter_blocks
@attr('shard_1') @attr('shard_1')
@override_settings(FIELD_OVERRIDE_PROVIDERS=( @override_settings(FIELD_OVERRIDE_PROVIDERS=(
'ccx.overrides.CustomCoursesForEdxOverrideProvider',)) 'ccx.overrides.CustomCoursesForEdxOverrideProvider',))
class TestFieldOverrides(ModuleStoreTestCase): class TestFieldOverrides(SharedModuleStoreTestCase):
""" """
Make sure field overrides behave in the expected manner. Make sure field overrides behave in the expected manner.
""" """
MODULESTORE = TEST_DATA_SPLIT_MODULESTORE MODULESTORE = TEST_DATA_SPLIT_MODULESTORE
def setUp(self): @classmethod
def setUpClass(cls):
""" """
Set up tests Course is created here and shared by all the class's tests.
""" """
super(TestFieldOverrides, self).setUp() super(TestFieldOverrides, cls).setUpClass()
self.course = course = CourseFactory.create() cls.course = CourseFactory.create()
self.course.enable_ccx = True cls.course.enable_ccx = True
# Create a course outline # Create a course outline
self.mooc_start = start = datetime.datetime( start = datetime.datetime(2010, 5, 12, 2, 42, tzinfo=pytz.UTC)
2010, 5, 12, 2, 42, tzinfo=pytz.UTC) due = datetime.datetime(2010, 7, 7, 0, 0, tzinfo=pytz.UTC)
self.mooc_due = due = datetime.datetime( chapters = [ItemFactory.create(start=start, parent=cls.course)
2010, 7, 7, 0, 0, tzinfo=pytz.UTC)
chapters = [ItemFactory.create(start=start, parent=course)
for _ in xrange(2)] for _ in xrange(2)]
sequentials = flatten([ sequentials = flatten([
[ItemFactory.create(parent=chapter) for _ in xrange(2)] [ItemFactory.create(parent=chapter) for _ in xrange(2)]
...@@ -57,8 +56,14 @@ class TestFieldOverrides(ModuleStoreTestCase): ...@@ -57,8 +56,14 @@ class TestFieldOverrides(ModuleStoreTestCase):
[ItemFactory.create(parent=vertical) for _ in xrange(2)] [ItemFactory.create(parent=vertical) for _ in xrange(2)]
for vertical in verticals]) for vertical in verticals])
def setUp(self):
"""
Set up tests
"""
super(TestFieldOverrides, self).setUp()
self.ccx = ccx = CustomCourseForEdX( self.ccx = ccx = CustomCourseForEdX(
course_id=course.id, course_id=self.course.id,
display_name='Test CCX', display_name='Test CCX',
coach=AdminFactory.create()) coach=AdminFactory.create())
ccx.save() ccx.save()
...@@ -70,7 +75,7 @@ class TestFieldOverrides(ModuleStoreTestCase): ...@@ -70,7 +75,7 @@ class TestFieldOverrides(ModuleStoreTestCase):
self.addCleanup(RequestCache.clear_request_cache) self.addCleanup(RequestCache.clear_request_cache)
inject_field_overrides(iter_blocks(ccx.course), course, AdminFactory.create()) inject_field_overrides(iter_blocks(ccx.course), self.course, AdminFactory.create())
def cleanup_provider_classes(): def cleanup_provider_classes():
""" """
......
...@@ -7,7 +7,7 @@ from nose.plugins.attrib import attr ...@@ -7,7 +7,7 @@ from nose.plugins.attrib import attr
from django.test.utils import override_settings from django.test.utils import override_settings
from xblock.field_data import DictFieldData from xblock.field_data import DictFieldData
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from ..field_overrides import ( from ..field_overrides import (
disable_overrides, disable_overrides,
...@@ -23,14 +23,21 @@ TESTUSER = "testuser" ...@@ -23,14 +23,21 @@ TESTUSER = "testuser"
@attr('shard_1') @attr('shard_1')
@override_settings(FIELD_OVERRIDE_PROVIDERS=( @override_settings(FIELD_OVERRIDE_PROVIDERS=(
'courseware.tests.test_field_overrides.TestOverrideProvider',)) 'courseware.tests.test_field_overrides.TestOverrideProvider',))
class OverrideFieldDataTests(ModuleStoreTestCase): class OverrideFieldDataTests(SharedModuleStoreTestCase):
""" """
Tests for `OverrideFieldData`. Tests for `OverrideFieldData`.
""" """
@classmethod
def setUpClass(cls):
"""
Course is created here and shared by all the class's tests.
"""
super(OverrideFieldDataTests, cls).setUpClass()
cls.course = CourseFactory.create(enable_ccx=True)
def setUp(self): def setUp(self):
super(OverrideFieldDataTests, self).setUp() super(OverrideFieldDataTests, self).setUp()
self.course = CourseFactory.create(enable_ccx=True)
OverrideFieldData.provider_classes = None OverrideFieldData.provider_classes = None
def tearDown(self): def tearDown(self):
......
...@@ -11,56 +11,68 @@ from django.test.utils import override_settings ...@@ -11,56 +11,68 @@ from django.test.utils import override_settings
from courseware.tests.helpers import LoginEnrollmentTestCase from courseware.tests.helpers import LoginEnrollmentTestCase
from courseware.tests.factories import GlobalStaffFactory from courseware.tests.factories import GlobalStaffFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from student.tests.factories import UserFactory
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
@attr('shard_1') @attr('shard_1')
class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase): class TestNavigation(SharedModuleStoreTestCase, LoginEnrollmentTestCase):
""" """
Check that navigation state is saved properly. Check that navigation state is saved properly.
""" """
STUDENT_INFO = [('view@test.com', 'foo'), ('view2@test.com', 'foo')] STUDENT_INFO = [('view@test.com', 'foo'), ('view2@test.com', 'foo')]
def setUp(self): @classmethod
super(TestNavigation, self).setUp() def setUpClass(cls):
# pylint: disable=super-method-not-called
self.test_course = CourseFactory.create() with super(TestNavigation, cls).setUpClassAndTestData():
self.course = CourseFactory.create() cls.test_course = CourseFactory.create()
self.chapter0 = ItemFactory.create(parent=self.course, cls.test_course_proctored = CourseFactory.create()
display_name='Overview') cls.course = CourseFactory.create()
self.chapter9 = ItemFactory.create(parent=self.course,
display_name='factory_chapter') @classmethod
self.section0 = ItemFactory.create(parent=self.chapter0, def setUpTestData(cls):
display_name='Welcome') cls.chapter0 = ItemFactory.create(parent=cls.course,
self.section9 = ItemFactory.create(parent=self.chapter9, display_name='Overview')
display_name='factory_section') cls.chapter9 = ItemFactory.create(parent=cls.course,
self.unit0 = ItemFactory.create(parent=self.section0, display_name='factory_chapter')
display_name='New Unit') cls.section0 = ItemFactory.create(parent=cls.chapter0,
display_name='Welcome')
self.chapterchrome = ItemFactory.create(parent=self.course, cls.section9 = ItemFactory.create(parent=cls.chapter9,
display_name='Chrome') display_name='factory_section')
self.chromelesssection = ItemFactory.create(parent=self.chapterchrome, cls.unit0 = ItemFactory.create(parent=cls.section0,
display_name='chromeless', display_name='New Unit')
chrome='none')
self.accordionsection = ItemFactory.create(parent=self.chapterchrome, cls.chapterchrome = ItemFactory.create(parent=cls.course,
display_name='accordion', display_name='Chrome')
chrome='accordion') cls.chromelesssection = ItemFactory.create(parent=cls.chapterchrome,
self.tabssection = ItemFactory.create(parent=self.chapterchrome, display_name='chromeless',
display_name='tabs', chrome='none')
chrome='tabs') cls.accordionsection = ItemFactory.create(parent=cls.chapterchrome,
self.defaultchromesection = ItemFactory.create( display_name='accordion',
parent=self.chapterchrome, chrome='accordion')
cls.tabssection = ItemFactory.create(parent=cls.chapterchrome,
display_name='tabs',
chrome='tabs')
cls.defaultchromesection = ItemFactory.create(
parent=cls.chapterchrome,
display_name='defaultchrome', display_name='defaultchrome',
) )
self.fullchromesection = ItemFactory.create(parent=self.chapterchrome, cls.fullchromesection = ItemFactory.create(parent=cls.chapterchrome,
display_name='fullchrome', display_name='fullchrome',
chrome='accordion,tabs') chrome='accordion,tabs')
self.tabtest = ItemFactory.create(parent=self.chapterchrome, cls.tabtest = ItemFactory.create(parent=cls.chapterchrome,
display_name='progress_tab', display_name='progress_tab',
default_tab='progress') default_tab='progress')
cls.staff_user = GlobalStaffFactory()
cls.user = UserFactory()
def setUp(self):
super(TestNavigation, self).setUp()
# Create student accounts and activate them. # Create student accounts and activate them.
for i in range(len(self.STUDENT_INFO)): for i in range(len(self.STUDENT_INFO)):
...@@ -69,8 +81,6 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -69,8 +81,6 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase):
self.create_account(username, email, password) self.create_account(username, email, password)
self.activate_user(email) self.activate_user(email)
self.staff_user = GlobalStaffFactory()
def assertTabActive(self, tabname, response): def assertTabActive(self, tabname, response):
''' Check if the progress tab is active in the tab set ''' ''' Check if the progress tab is active in the tab set '''
for line in response.content.split('\n'): for line in response.content.split('\n'):
...@@ -278,9 +288,9 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -278,9 +288,9 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase):
email, password = self.STUDENT_INFO[0] email, password = self.STUDENT_INFO[0]
self.login(email, password) self.login(email, password)
self.enroll(self.test_course, True) self.enroll(self.test_course_proctored, True)
test_course_id = self.test_course.id.to_deprecated_string() test_course_id = self.test_course_proctored.id.to_deprecated_string()
with patch.dict(settings.FEATURES, {'ENABLE_SPECIAL_EXAMS': False}): with patch.dict(settings.FEATURES, {'ENABLE_SPECIAL_EXAMS': False}):
url = reverse( url = reverse(
...@@ -302,10 +312,10 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -302,10 +312,10 @@ class TestNavigation(ModuleStoreTestCase, LoginEnrollmentTestCase):
# now set up a course which is proctored enabled # now set up a course which is proctored enabled
self.test_course.enable_proctored_exams = True self.test_course_proctored.enable_proctored_exams = True
self.test_course.save() self.test_course_proctored.save()
modulestore().update_item(self.test_course, self.user.id) modulestore().update_item(self.test_course_proctored, self.user.id)
resp = self.client.get(url) resp = self.client.get(url)
......
...@@ -18,7 +18,7 @@ import dashboard.git_import as git_import ...@@ -18,7 +18,7 @@ import dashboard.git_import as git_import
from dashboard.git_import import GitImportError from dashboard.git_import import GitImportError
from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.mongo_connection import MONGO_PORT_NUM, MONGO_HOST from xmodule.modulestore.tests.mongo_connection import MONGO_PORT_NUM, MONGO_HOST
...@@ -37,7 +37,7 @@ FEATURES_WITH_SSL_AUTH['AUTH_USE_CERTIFICATES'] = True ...@@ -37,7 +37,7 @@ FEATURES_WITH_SSL_AUTH['AUTH_USE_CERTIFICATES'] = True
@override_settings(MONGODB_LOG=TEST_MONGODB_LOG) @override_settings(MONGODB_LOG=TEST_MONGODB_LOG)
@unittest.skipUnless(settings.FEATURES.get('ENABLE_SYSADMIN_DASHBOARD'), @unittest.skipUnless(settings.FEATURES.get('ENABLE_SYSADMIN_DASHBOARD'),
"ENABLE_SYSADMIN_DASHBOARD not set") "ENABLE_SYSADMIN_DASHBOARD not set")
class TestGitAddCourse(ModuleStoreTestCase): class TestGitAddCourse(SharedModuleStoreTestCase):
""" """
Tests the git_add_course management command for proper functions. Tests the git_add_course management command for proper functions.
""" """
......
...@@ -59,10 +59,6 @@ class SysadminDashboardView(TemplateView): ...@@ -59,10 +59,6 @@ class SysadminDashboardView(TemplateView):
""" """
self.def_ms = modulestore() self.def_ms = modulestore()
self.is_using_mongo = True
if self.def_ms.get_modulestore_type(None) == 'xml':
self.is_using_mongo = False
self.msg = u'' self.msg = u''
self.datatable = [] self.datatable = []
super(SysadminDashboardView, self).__init__(**kwargs) super(SysadminDashboardView, self).__init__(**kwargs)
...@@ -374,10 +370,7 @@ class Courses(SysadminDashboardView): ...@@ -374,10 +370,7 @@ class Courses(SysadminDashboardView):
return _("The git repo location should end with '.git', " return _("The git repo location should end with '.git', "
"and be a valid url") "and be a valid url")
if self.is_using_mongo: return self.import_mongo_course(gitloc, branch)
return self.import_mongo_course(gitloc, branch)
return self.import_xml_course(gitloc, branch)
def import_mongo_course(self, gitloc, branch): def import_mongo_course(self, gitloc, branch):
""" """
...@@ -429,80 +422,6 @@ class Courses(SysadminDashboardView): ...@@ -429,80 +422,6 @@ class Courses(SysadminDashboardView):
msg += u"<pre>{0}</pre>".format(escape(ret)) msg += u"<pre>{0}</pre>".format(escape(ret))
return msg return msg
def import_xml_course(self, gitloc, branch):
"""Imports a git course into the XMLModuleStore"""
msg = u''
if not getattr(settings, 'GIT_IMPORT_WITH_XMLMODULESTORE', False):
# Translators: "GIT_IMPORT_WITH_XMLMODULESTORE" is a variable name.
# "XMLModuleStore" and "MongoDB" are database systems. You should not
# translate these names.
return _('Refusing to import. GIT_IMPORT_WITH_XMLMODULESTORE is '
'not turned on, and it is generally not safe to import '
'into an XMLModuleStore with multithreaded. We '
'recommend you enable the MongoDB based module store '
'instead, unless this is a development environment.')
cdir = (gitloc.rsplit('/', 1)[1])[:-4]
gdir = settings.DATA_DIR / cdir
if os.path.exists(gdir):
msg += _("The course {0} already exists in the data directory! "
"(reloading anyway)").format(cdir)
cmd = ['git', 'pull', ]
cwd = gdir
else:
cmd = ['git', 'clone', gitloc, ]
cwd = settings.DATA_DIR
cwd = os.path.abspath(cwd)
try:
cmd_output = escape(
subprocess.check_output(cmd, stderr=subprocess.STDOUT, cwd=cwd)
)
except subprocess.CalledProcessError as ex:
log.exception('Git pull or clone output was: %r', ex.output)
# Translators: unable to download the course content from
# the source git repository. Clone occurs if this is brand
# new, and pull is when it is being updated from the
# source.
return _('Unable to clone or pull repository. Please check '
'your url. Output was: {0!r}').format(ex.output)
msg += u'<pre>{0}</pre>'.format(cmd_output)
if not os.path.exists(gdir):
msg += _('Failed to clone repository to {directory_name}').format(directory_name=gdir)
return msg
# Change branch if specified
if branch:
try:
git_import.switch_branch(branch, gdir)
except GitImportError as ex:
return str(ex)
# Translators: This is a git repository branch, which is a
# specific version of a courses content
msg += u'<p>{0}</p>'.format(
_('Successfully switched to branch: '
'{branch_name}').format(branch_name=branch))
self.def_ms.try_load_course(os.path.abspath(gdir))
errlog = self.def_ms.errored_courses.get(cdir, '')
if errlog:
msg += u'<hr width="50%"><pre>{0}</pre>'.format(escape(errlog))
else:
course = self.def_ms.courses[os.path.abspath(gdir)]
msg += _('Loaded course {course_name}<br/>Errors:').format(
course_name="{} {}".format(cdir, course.display_name)
)
errors = self.def_ms.get_course_errors(course.id)
if not errors:
msg += u'None'
else:
msg += u'<ul>'
for (summary, err) in errors:
msg += u'<li><pre>{0}: {1}</pre></li>'.format(escape(summary),
escape(err))
msg += u'</ul>'
return msg
def make_datatable(self): def make_datatable(self):
"""Creates course information datatable""" """Creates course information datatable"""
......
...@@ -26,7 +26,8 @@ from xmodule.modulestore import ModuleStoreEnum ...@@ -26,7 +26,8 @@ from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ( from xmodule.modulestore.tests.django_utils import (
ModuleStoreTestCase, ModuleStoreTestCase,
TEST_DATA_MONGO_MODULESTORE SharedModuleStoreTestCase,
TEST_DATA_MONGO_MODULESTORE,
) )
from xmodule.modulestore.tests.factories import check_mongo_calls, CourseFactory, ItemFactory from xmodule.modulestore.tests.factories import check_mongo_calls, CourseFactory, ItemFactory
...@@ -1282,13 +1283,20 @@ class CommentsServiceRequestHeadersTestCase(UrlResetMixin, ModuleStoreTestCase): ...@@ -1282,13 +1283,20 @@ class CommentsServiceRequestHeadersTestCase(UrlResetMixin, ModuleStoreTestCase):
self.assert_all_calls_have_header(mock_request, "X-Edx-Api-Key", "test_api_key") self.assert_all_calls_have_header(mock_request, "X-Edx-Api-Key", "test_api_key")
class InlineDiscussionUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class InlineDiscussionUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self):
super(InlineDiscussionUnicodeTestCase, self).setUp()
self.course = CourseFactory.create() @classmethod
self.student = UserFactory.create() def setUpClass(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) # pylint: disable=super-method-not-called
with super(InlineDiscussionUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create()
@classmethod
def setUpTestData(cls):
super(InlineDiscussionUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
...@@ -1305,13 +1313,19 @@ class InlineDiscussionUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): ...@@ -1305,13 +1313,19 @@ class InlineDiscussionUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin):
self.assertEqual(response_data["discussion_data"][0]["body"], text) self.assertEqual(response_data["discussion_data"][0]["body"], text)
class ForumFormDiscussionUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class ForumFormDiscussionUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self): @classmethod
super(ForumFormDiscussionUnicodeTestCase, self).setUp() def setUpClass(cls):
# pylint: disable=super-method-not-called
with super(ForumFormDiscussionUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create()
self.course = CourseFactory.create() @classmethod
self.student = UserFactory.create() def setUpTestData(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) super(ForumFormDiscussionUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
...@@ -1377,13 +1391,20 @@ class ForumDiscussionXSSTestCase(UrlResetMixin, ModuleStoreTestCase): ...@@ -1377,13 +1391,20 @@ class ForumDiscussionXSSTestCase(UrlResetMixin, ModuleStoreTestCase):
self.assertNotIn(malicious_code, resp.content) self.assertNotIn(malicious_code, resp.content)
class ForumDiscussionSearchUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class ForumDiscussionSearchUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self):
super(ForumDiscussionSearchUnicodeTestCase, self).setUp()
self.course = CourseFactory.create() @classmethod
self.student = UserFactory.create() def setUpClass(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) # pylint: disable=super-method-not-called
with super(ForumDiscussionSearchUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create()
@classmethod
def setUpTestData(cls):
super(ForumDiscussionSearchUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
...@@ -1403,13 +1424,20 @@ class ForumDiscussionSearchUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin ...@@ -1403,13 +1424,20 @@ class ForumDiscussionSearchUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin
self.assertEqual(response_data["discussion_data"][0]["body"], text) self.assertEqual(response_data["discussion_data"][0]["body"], text)
class SingleThreadUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class SingleThreadUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self):
super(SingleThreadUnicodeTestCase, self).setUp()
self.course = CourseFactory.create(discussion_topics={'dummy_discussion_id': {'id': 'dummy_discussion_id'}}) @classmethod
self.student = UserFactory.create() def setUpClass(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) # pylint: disable=super-method-not-called
with super(SingleThreadUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create(discussion_topics={'dummy_discussion_id': {'id': 'dummy_discussion_id'}})
@classmethod
def setUpTestData(cls):
super(SingleThreadUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
...@@ -1426,13 +1454,20 @@ class SingleThreadUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): ...@@ -1426,13 +1454,20 @@ class SingleThreadUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin):
self.assertEqual(response_data["content"]["body"], text) self.assertEqual(response_data["content"]["body"], text)
class UserProfileUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class UserProfileUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self):
super(UserProfileUnicodeTestCase, self).setUp()
self.course = CourseFactory.create() @classmethod
self.student = UserFactory.create() def setUpClass(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) # pylint: disable=super-method-not-called
with super(UserProfileUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create()
@classmethod
def setUpTestData(cls):
super(UserProfileUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
...@@ -1448,13 +1483,20 @@ class UserProfileUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): ...@@ -1448,13 +1483,20 @@ class UserProfileUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin):
self.assertEqual(response_data["discussion_data"][0]["body"], text) self.assertEqual(response_data["discussion_data"][0]["body"], text)
class FollowedThreadsUnicodeTestCase(ModuleStoreTestCase, UnicodeTestMixin): class FollowedThreadsUnicodeTestCase(SharedModuleStoreTestCase, UnicodeTestMixin):
def setUp(self):
super(FollowedThreadsUnicodeTestCase, self).setUp()
self.course = CourseFactory.create() @classmethod
self.student = UserFactory.create() def setUpClass(cls):
CourseEnrollmentFactory(user=self.student, course_id=self.course.id) # pylint: disable=super-method-not-called
with super(FollowedThreadsUnicodeTestCase, cls).setUpClassAndTestData():
cls.course = CourseFactory.create()
@classmethod
def setUpTestData(cls):
super(FollowedThreadsUnicodeTestCase, cls).setUpTestData()
cls.student = UserFactory.create()
CourseEnrollmentFactory(user=cls.student, course_id=cls.course.id)
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
def _test_unicode_data(self, text, mock_request): def _test_unicode_data(self, text, mock_request):
......
...@@ -12,7 +12,7 @@ from openedx.core.lib.django_startup import autostartup ...@@ -12,7 +12,7 @@ from openedx.core.lib.django_startup import autostartup
import edxmako import edxmako
import logging import logging
import analytics import analytics
from monkey_patch import third_party_auth from monkey_patch import third_party_auth, django_db_models_options
import xmodule.x_module import xmodule.x_module
...@@ -29,6 +29,7 @@ def run(): ...@@ -29,6 +29,7 @@ def run():
Executed during django startup Executed during django startup
""" """
third_party_auth.patch() third_party_auth.patch()
django_db_models_options.patch()
# To override the settings before executing the autostartup() for python-social-auth # To override the settings before executing the autostartup() for python-social-auth
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH', False): if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH', False):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment