Commit e2d98520 by Brian Wilson

Move updates for InstructorTask into BaseInstructorTask abstract class.

parent c01fa459
@@ -21,16 +21,18 @@ of the query for traversing StudentModule objects.
 """
 from celery import task
 from functools import partial
-from instructor_task.tasks_helper import (run_main_task,
+from instructor_task.tasks_helper import (
+    run_main_task,
+    BaseInstructorTask,
     perform_module_state_update,
     rescore_problem_module_state,
     reset_attempts_module_state,
     delete_problem_module_state,
     )
 from bulk_email.tasks import perform_delegate_email_batches


-@task
+@task(base=BaseInstructorTask)
 def rescore_problem(entry_id, xmodule_instance_args):
     """Rescores a problem in a course, for all students or one specific student.
@@ -59,7 +61,7 @@ def rescore_problem(entry_id, xmodule_instance_args):
     return run_main_task(entry_id, visit_fcn, action_name)


-@task
+@task(base=BaseInstructorTask)
 def reset_problem_attempts(entry_id, xmodule_instance_args):
     """Resets problem attempts to zero for a particular problem for all students in a course.
@@ -80,7 +82,7 @@ def reset_problem_attempts(entry_id, xmodule_instance_args):
     return run_main_task(entry_id, visit_fcn, action_name)


-@task
+@task(base=BaseInstructorTask)
 def delete_problem_state(entry_id, xmodule_instance_args):
     """Deletes problem state entirely for all students on a particular problem in a course.
@@ -101,7 +103,7 @@ def delete_problem_state(entry_id, xmodule_instance_args):
     return run_main_task(entry_id, visit_fcn, action_name)


-@task
+@task(base=BaseInstructorTask)
 def send_bulk_course_email(entry_id, xmodule_instance_args):
     """Sends emails to recipients in a course.
@@ -116,4 +118,4 @@ def send_bulk_course_email(entry_id, xmodule_instance_args):
     """
     action_name = 'emailed'
     visit_fcn = perform_delegate_email_batches
-    return run_main_task(entry_id, visit_fcn, action_name, spawns_subtasks=True)
+    return run_main_task(entry_id, visit_fcn, action_name)
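Aside (not part of the commit): the `base=` argument used above is standard Celery. A task declared with `@task(base=SomeTaskClass)` keeps its plain-function body, while the abstract base class contributes lifecycle hooks such as `on_success` and `on_failure`. A minimal, self-contained sketch of the pattern, with purely illustrative names:

```python
from celery import Celery, Task

app = Celery('sketch')  # no broker needed: apply() below runs the task in-process


class TrackedTask(Task):
    """Abstract base class whose hooks Celery invokes around the task body."""
    abstract = True  # never registered as a runnable task itself

    def on_success(self, retval, task_id, args, kwargs):
        # retval is whatever the task function returned; args[0] plays the
        # role that entry_id plays in the commit above.
        print("entry %s finished: %r" % (args[0], retval))

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        print("entry %s failed: %r" % (args[0], exc))


@app.task(base=TrackedTask)
def do_work(entry_id):
    # The body only performs the work; bookkeeping lives in the base class.
    return {'attempted': 1, 'succeeded': 1}


# Eager, in-process execution: both the body and on_success run right here.
print(do_work.apply(args=[42]).get())
```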
@@ -5,10 +5,8 @@ running state of a course.
 """
 import json
 from time import time
-from sys import exc_info
-from traceback import format_exc

-from celery import current_task
+from celery import Task, current_task
 from celery.utils.log import get_task_logger
 from celery.states import SUCCESS, FAILURE
@@ -37,6 +35,66 @@ UPDATE_STATUS_FAILED = 'failed'
 UPDATE_STATUS_SKIPPED = 'skipped'


+class BaseInstructorTask(Task):
+    """
+    Base task class for use with InstructorTask models.
+
+    Permits updating information about task in corresponding InstructorTask for monitoring purposes.
+
+    Assumes that the entry_id of the InstructorTask model is the first argument to the task.
+    """
+    abstract = True
+
+    def on_success(self, task_progress, task_id, args, kwargs):
+        """
+        Update InstructorTask object corresponding to this task with info about success.
+
+        Updates task_output and task_state. But it shouldn't actually do anything
+        if the task is only creating subtasks to actually do the work.
+        """
+        TASK_LOG.info('Task success returned: %r' % (self.request, ))
+        # We should be able to find the InstructorTask object to update
+        # based on the task_id here, without having to dig into the
+        # original args to the task. On the other hand, the entry_id
+        # is the first value passed to all such args, so we'll use that.
+        # And we assume that it exists, else we would already have had a failure.
+        entry_id = args[0]
+        entry = InstructorTask.objects.get(pk=entry_id)
+        # Check to see if any subtasks had been defined as part of this task.
+        # If not, then we know that we're done. (If so, let the subtasks
+        # handle updating task_state themselves.)
+        if len(entry.subtasks) == 0:
+            entry.task_output = InstructorTask.create_output_for_success(task_progress)
+            entry.task_state = SUCCESS
+            entry.save_now()
+
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
+        """
+        Update InstructorTask object corresponding to this task with info about failure.
+
+        Fetches and updates exception and traceback information on failure.
+        """
+        TASK_LOG.info('Task failure returned: %r' % (self.request, ))
+        entry_id = args[0]
+        try:
+            entry = InstructorTask.objects.get(pk=entry_id)
+        except InstructorTask.DoesNotExist:
+            # if the InstructorTask object does not exist, then there's no point
+            # trying to update it.
+            pass
+        else:
+            TASK_LOG.warning("background task (%s) failed: %s %s", task_id, einfo.exception, einfo.traceback)
+            entry.task_output = InstructorTask.create_output_for_failure(einfo.exception, einfo.traceback)
+            entry.task_state = FAILURE
+            entry.save_now()
+
+    def on_retry(self, exc, task_id, args, kwargs, einfo):
+        # We don't expect this to be called for top-level tasks, at the moment....
+        # If it were, not sure what kind of status to report for it.
+        # But it would be good to know that it's being called, so at least log it.
+        TASK_LOG.info('Task retry returned: %r' % (self.request, ))
+
+
 class UpdateProblemModuleStateError(Exception):
     """
     Error signaling a fatal condition while updating problem modules.
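Usage note (illustrative, not from this commit): when a task bound to a base class like this raises, Celery passes `on_failure` an `ExceptionInfo` object whose `.exception` and `.traceback` attributes are exactly what the handler above writes into the InstructorTask row, which is why the old `sys.exc_info()`/`format_exc()` plumbing could be deleted. A minimal sketch with made-up names:

```python
from celery import Celery, Task

app = Celery('sketch')


class FailureRecorder(Task):
    abstract = True

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        # einfo.exception is the raised exception instance and einfo.traceback
        # is an already-formatted traceback string.
        print("entry %s failed: %s" % (args[0], einfo.exception))


@app.task(base=FailureRecorder)
def failing(entry_id):
    raise ValueError("simulated update error")


# Eager execution: the error is captured, on_failure fires, and the result
# object ends up in the FAILURE state (calling get() would re-raise the error).
result = failing.apply(args=[7])
assert result.state == 'FAILURE'
```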
@@ -162,7 +220,7 @@ def perform_module_state_update(update_fcn, filter_fcn, _entry_id, course_id, ta
     return task_progress


-def run_main_task(entry_id, task_fcn, action_name, spawns_subtasks=False):
+def run_main_task(entry_id, task_fcn, action_name):
     """
     Applies the `task_fcn` to the arguments defined in `entry_id` InstructorTask.
@@ -221,9 +279,6 @@ def run_main_task(entry_id, task_fcn, action_name, spawns_subtasks=False):
     TASK_LOG.info('Starting update (nothing %s yet): %s', action_name, task_info_string)

-    # Now that we have an entry we can try to catch failures:
-    task_progress = None
-    try:
     # Check that the task_id submitted in the InstructorTask matches the current task
     # that is running.
     request_task_id = _get_current_task().request.id
@@ -237,49 +292,6 @@ def run_main_task(entry_id, task_fcn, action_name, spawns_subtasks=False):
     with dog_stats_api.timer('instructor_tasks.time.overall', tags=['action:{name}'.format(name=action_name)]):
         task_progress = task_fcn(entry_id, course_id, task_input, action_name)

-    # If we get here, we assume we've succeeded, so update the InstructorTask entry in anticipation.
-    # But we do this within the try, in case creating the task_output causes an exception to be
-    # raised.
-    # TODO: This is not the case if there are outstanding subtasks that were spawned asynchronously
-    # as part of the main task. There is probably some way to represent this more elegantly, but for
-    # now, we will just use an explicit flag.
-    if spawns_subtasks:
-        # TODO: UPDATE THIS.
-        # we change the rules here. If it's a task with subtasks running, then we
-        # explicitly set its state, with the idea that progress will be updated
-        # directly into the InstructorTask object, rather than into the parent task's
-        # AsyncResult object. This is because we have to write to the InstructorTask
-        # object anyway, so we may as well put status in there. And because multiple
-        # clients are writing to it, we need the locking that a DB can provide, rather
-        # than the speed that the AsyncResult provides.
-        # So we need to change the logic of the monitor to pull status from the
-        # InstructorTask directly when the state is PROGRESS, and to pull from the
-        # AsyncResult when it's running but not marked as in PROGRESS state. (I.e.
-        # if it's started.) Admittedly, it's misnamed, but it should work.
-        # But we've already started the subtasks by the time we get here,
-        # so these values should already have been written. Too late.
-        # entry.task_output = InstructorTask.create_output_for_success(task_progress)
-        # entry.task_state = PROGRESS
-        # Weird. Note that by exiting this function successfully, will
-        # result in the AsyncResult for this task as being marked as SUCCESS.
-        # Below, we were just marking the entry to match. But it shouldn't
-        # match, if it's not really done.
-        pass
-    else:
-        entry.task_output = InstructorTask.create_output_for_success(task_progress)
-        entry.task_state = SUCCESS
-        entry.save_now()
-    except Exception:
-        # try to write out the failure to the entry before failing
-        _, exception, traceback = exc_info()
-        traceback_string = format_exc(traceback) if traceback is not None else ''
-        TASK_LOG.warning("background task (%s) failed: %s %s", task_id, exception, traceback_string)
-        entry.task_output = InstructorTask.create_output_for_failure(exception, traceback_string)
-        entry.task_state = FAILURE
-        entry.save_now()
-        raise

     # Release any queries that the connection has been hanging onto:
     reset_queries()
...
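The long comment deleted above captures the underlying design question: when a main task fans out subtasks, their progress is written into the shared InstructorTask row (which the database can lock) rather than into the parent's AsyncResult (which cannot). Purely as an illustration of that locking point, and not code from this commit, a recent Django would serialize such concurrent read-modify-write updates with `select_for_update()` (the import path below is assumed):

```python
import json

from django.db import transaction

from instructor_task.models import InstructorTask  # assumed import path


def record_subtask_progress(entry_id, succeeded):
    """Illustrative only: merge one subtask's progress under a row lock."""
    with transaction.atomic():
        # SELECT ... FOR UPDATE blocks other workers touching the same row
        # until this transaction commits.
        entry = InstructorTask.objects.select_for_update().get(pk=entry_id)
        progress = json.loads(entry.task_output) if entry.task_output else {}
        progress['succeeded'] = progress.get('succeeded', 0) + succeeded
        entry.task_output = json.dumps(progress)
        entry.save()
```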
@@ -69,7 +69,7 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         'request_info': {},
     }

-    def _run_task_with_mock_celery(self, task_function, entry_id, task_id, expected_failure_message=None):
+    def _run_task_with_mock_celery(self, task_class, entry_id, task_id, expected_failure_message=None):
         """Submit a task and mock how celery provides a current_task."""
         self.current_task = Mock()
         self.current_task.request = Mock()
@@ -77,32 +77,34 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.current_task.update_state = Mock()
         if expected_failure_message is not None:
             self.current_task.update_state.side_effect = TestTaskFailure(expected_failure_message)
+        task_args = [entry_id, self._get_xmodule_instance_args()]
         with patch('instructor_task.tasks_helper._get_current_task') as mock_get_task:
             mock_get_task.return_value = self.current_task
-            return task_function(entry_id, self._get_xmodule_instance_args())
+            return task_class.apply(task_args, task_id=task_id).get()

-    def _test_missing_current_task(self, task_function):
-        """Check that a task_function fails when celery doesn't provide a current_task."""
+    def _test_missing_current_task(self, task_class):
+        """Check that a task_class fails when celery doesn't provide a current_task."""
         task_entry = self._create_input_entry()
         with self.assertRaises(UpdateProblemModuleStateError):
-            task_function(task_entry.id, self._get_xmodule_instance_args())
+            task_class(task_entry.id, self._get_xmodule_instance_args())

-    def _test_undefined_course(self, task_function):
+    def _test_undefined_course(self, task_class):
         # run with celery, but no course defined
         task_entry = self._create_input_entry(course_id="bogus/course/id")
         with self.assertRaises(ItemNotFoundError):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id)
+            self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id)

-    def _test_undefined_problem(self, task_function):
+    def _test_undefined_problem(self, task_class):
         """Run with celery, but no problem defined."""
         task_entry = self._create_input_entry()
         with self.assertRaises(ItemNotFoundError):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id)
+            self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id)

-    def _test_run_with_task(self, task_function, action_name, expected_num_succeeded):
+    def _test_run_with_task(self, task_class, action_name, expected_num_succeeded):
         """Run a task and check the number of StudentModules processed."""
         task_entry = self._create_input_entry()
-        status = self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id)
+        status = self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id)
         # check return value
         self.assertEquals(status.get('attempted'), expected_num_succeeded)
         self.assertEquals(status.get('succeeded'), expected_num_succeeded)
@@ -114,10 +116,10 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.assertEquals(json.loads(entry.task_output), status)
         self.assertEquals(entry.task_state, SUCCESS)

-    def _test_run_with_no_state(self, task_function, action_name):
+    def _test_run_with_no_state(self, task_class, action_name):
         """Run with no StudentModules defined for the current problem."""
         self.define_option_problem(PROBLEM_URL_NAME)
-        self._test_run_with_task(task_function, action_name, 0)
+        self._test_run_with_task(task_class, action_name, 0)

     def _create_students_with_state(self, num_students, state=None, grade=0, max_grade=1):
         """Create students, a problem, and StudentModule objects for testing"""
@@ -145,12 +147,12 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
             state = json.loads(module.state)
             self.assertEquals(state['attempts'], num_attempts)

-    def _test_run_with_failure(self, task_function, expected_message):
+    def _test_run_with_failure(self, task_class, expected_message):
         """Run a task and trigger an artificial failure with the given message."""
         task_entry = self._create_input_entry()
         self.define_option_problem(PROBLEM_URL_NAME)
         with self.assertRaises(TestTaskFailure):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id, expected_message)
+            self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id, expected_message)
         # compare with entry in table:
         entry = InstructorTask.objects.get(id=task_entry.id)
         self.assertEquals(entry.task_state, FAILURE)
@@ -158,7 +160,7 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.assertEquals(output['exception'], 'TestTaskFailure')
         self.assertEquals(output['message'], expected_message)

-    def _test_run_with_long_error_msg(self, task_function):
+    def _test_run_with_long_error_msg(self, task_class):
         """
         Run with an error message that is so long it will require
         truncation (as well as the jettisoning of the traceback).
@@ -167,7 +169,7 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.define_option_problem(PROBLEM_URL_NAME)
         expected_message = "x" * 1500
         with self.assertRaises(TestTaskFailure):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id, expected_message)
+            self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id, expected_message)
         # compare with entry in table:
         entry = InstructorTask.objects.get(id=task_entry.id)
         self.assertEquals(entry.task_state, FAILURE)
@@ -177,7 +179,7 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.assertEquals(output['message'], expected_message[:len(output['message']) - 3] + "...")
         self.assertTrue('traceback' not in output)

-    def _test_run_with_short_error_msg(self, task_function):
+    def _test_run_with_short_error_msg(self, task_class):
         """
         Run with an error message that is short enough to fit
         in the output, but long enough that the traceback won't.
@@ -187,7 +189,7 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.define_option_problem(PROBLEM_URL_NAME)
         expected_message = "x" * 900
         with self.assertRaises(TestTaskFailure):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id, expected_message)
+            self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id, expected_message)
         # compare with entry in table:
         entry = InstructorTask.objects.get(id=task_entry.id)
         self.assertEquals(entry.task_state, FAILURE)
@@ -198,33 +200,6 @@ class TestInstructorTasks(InstructorTaskModuleTestCase):
         self.assertEquals(output['traceback'][-3:], "...")


-class TestGeneralInstructorTask(TestInstructorTasks):
-    """Tests instructor task mechanism using custom tasks"""
-    def test_successful_result_too_long(self):
-        # while we don't expect the existing tasks to generate output that is too
-        # long, we can test the framework will handle such an occurrence.
-        task_entry = self._create_input_entry()
-        self.define_option_problem(PROBLEM_URL_NAME)
-        action_name = 'x' * 1000
-        # define a custom task that does nothing:
-        update_fcn = lambda(_module_descriptor, _student_module): UPDATE_STATUS_SUCCEEDED
-        visit_fcn = partial(perform_module_state_update, update_fcn, None)
-        task_function = (lambda entry_id, xmodule_instance_args:
-                         run_main_task(entry_id, visit_fcn, action_name))
-        # run the task:
-        with self.assertRaises(ValueError):
-            self._run_task_with_mock_celery(task_function, task_entry.id, task_entry.task_id)
-        # compare with entry in table:
-        entry = InstructorTask.objects.get(id=task_entry.id)
-        self.assertEquals(entry.task_state, FAILURE)
-        self.assertGreater(1023, len(entry.task_output))
-        output = json.loads(entry.task_output)
-        self.assertEquals(output['exception'], 'ValueError')
-        self.assertTrue("Length of task output is too long" in output['message'])
-        self.assertTrue('traceback' not in output)
-
-
 class TestRescoreInstructorTask(TestInstructorTasks):
     """Tests problem-rescoring instructor task."""
@@ -257,8 +232,8 @@ class TestRescoreInstructorTask(TestInstructorTasks):
         task_entry = self._create_input_entry()
         mock_instance = MagicMock()
         del mock_instance.rescore_problem
-        # TODO: figure out why this patch isn't working
-        # with patch('courseware.module_render.get_module_for_descriptor_internal') as mock_get_module:
+        # TODO: figure out why this patch isn't working, when it seems to work fine for
+        # the test_rescoring_success test below. Weird.
         with patch('courseware.module_render.get_module_for_descriptor_internal') as mock_get_module:
             mock_get_module.return_value = mock_instance
             with self.assertRaises(UpdateProblemModuleStateError):
...
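One more aside on the test change above (illustrative, not from the commit): switching the helpers from calling the task function directly to `task_class.apply(...)` matters because `apply()` drives the task through Celery's tracer in the current process, so the new `BaseInstructorTask.on_success`/`on_failure` hooks actually run and can update the InstructorTask row. A tiny self-contained example of `apply()`, with invented names:

```python
from celery import Celery

app = Celery('sketch')


@app.task
def summarize(entry_id, amount):
    return {'entry': entry_id, 'attempted': amount, 'succeeded': amount}


# apply() executes immediately in this process (no worker, no broker) and
# returns an EagerResult; get() hands back the task's return value.
result = summarize.apply(args=[1, 5], task_id='fixed-task-id')
assert result.id == 'fixed-task-id'
assert result.get() == {'entry': 1, 'attempted': 5, 'succeeded': 5}
```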