Commit 04f90fee by Brian Wilson

Fix subtask code to handle running in eager mode (as tests do).

parent a4c35ac4
@@ -11,7 +11,7 @@ from itertools import cycle
 from mock import patch, Mock
 from smtplib import SMTPDataError, SMTPServerDisconnected
-from celery.states import SUCCESS
+from celery.states import SUCCESS, FAILURE
 # from django.test.utils import override_settings
 from django.conf import settings
@@ -91,14 +91,40 @@ class TestBulkEmailInstructorTask(InstructorTaskCourseTestCase):
         ]
         return students

-    def _test_run_with_task(self, task_class, action_name, total, succeeded, failed=0, skipped=0):
+    def _assert_single_subtask_status(self, entry, succeeded, failed=0, skipped=0, retried_nomax=0, retried_withmax=0):
+        """Compare counts with 'subtasks' entry in InstructorTask table."""
+        subtask_info = json.loads(entry.subtasks)
+        # verify subtask-level counts:
+        self.assertEquals(subtask_info.get('total'), 1)
+        self.assertEquals(subtask_info.get('succeeded'), 1 if succeeded > 0 else 0)
+        self.assertEquals(subtask_info['failed'], 0 if succeeded > 0 else 1)
+        # self.assertEquals(subtask_info['retried'], retried_nomax + retried_withmax)
+        # verify individual subtask status:
+        subtask_status_info = subtask_info['status']
+        task_id_list = subtask_status_info.keys()
+        self.assertEquals(len(task_id_list), 1)
+        task_id = task_id_list[0]
+        subtask_status = subtask_status_info.get(task_id)
+        print("Testing subtask status: {}".format(subtask_status))
+        self.assertEquals(subtask_status['task_id'], task_id)
+        self.assertEquals(subtask_status['attempted'], succeeded + failed)
+        self.assertEquals(subtask_status['succeeded'], succeeded)
+        self.assertEquals(subtask_status['skipped'], skipped)
+        self.assertEquals(subtask_status['failed'], failed)
+        self.assertEquals(subtask_status['retried_nomax'], retried_nomax)
+        self.assertEquals(subtask_status['retried_withmax'], retried_withmax)
+        self.assertEquals(subtask_status['state'], SUCCESS if succeeded > 0 else FAILURE)
+
+    def _test_run_with_task(self, task_class, action_name, total, succeeded, failed=0, skipped=0, retried_nomax=0, retried_withmax=0):
         """Run a task and check the number of emails processed."""
         task_entry = self._create_input_entry()
         parent_status = self._run_task_with_mock_celery(task_class, task_entry.id, task_entry.task_id)
         # check return value
         self.assertEquals(parent_status.get('total'), total)
         self.assertEquals(parent_status.get('action_name'), action_name)
-        # compare with entry in table:
+        # compare with task_output entry in InstructorTask table:
         entry = InstructorTask.objects.get(id=task_entry.id)
         status = json.loads(entry.task_output)
         self.assertEquals(status.get('attempted'), succeeded + failed)
@@ -109,9 +135,10 @@ class TestBulkEmailInstructorTask(InstructorTaskCourseTestCase):
         self.assertEquals(status.get('action_name'), action_name)
         self.assertGreater(status.get('duration_ms'), 0)
         self.assertEquals(entry.task_state, SUCCESS)
+        self._assert_single_subtask_status(entry, succeeded, failed, skipped, retried_nomax, retried_withmax)

     def test_successful(self):
-        num_students = settings.EMAILS_PER_TASK
+        num_students = settings.EMAILS_PER_TASK - 1
         self._create_students(num_students)
         # we also send email to the instructor:
         num_emails = num_students + 1
@@ -119,9 +146,9 @@ class TestBulkEmailInstructorTask(InstructorTaskCourseTestCase):
             get_conn.return_value.send_messages.side_effect = cycle([None])
             self._test_run_with_task(send_bulk_course_email, 'emailed', num_emails, num_emails)

-    def test_data_err_fail(self):
+    def test_smtp_blacklisted_user(self):
         # Test that celery handles permanent SMTPDataErrors by failing and not retrying.
-        num_students = settings.EMAILS_PER_TASK
+        num_students = settings.EMAILS_PER_TASK - 1
         self._create_students(num_students)
         # we also send email to the instructor:
         num_emails = num_students + 1
@@ -144,19 +171,31 @@ class TestBulkEmailInstructorTask(InstructorTaskCourseTestCase):
         with patch('bulk_email.tasks.get_connection', autospec=True) as get_conn:
             # have every other mail attempt fail due to disconnection:
             get_conn.return_value.send_messages.side_effect = cycle([SMTPServerDisconnected(425, "Disconnecting"), None])
-            self._test_run_with_task(send_bulk_course_email, 'emailed', num_emails, expected_succeeds, failed=expected_fails)
+            self._test_run_with_task(
+                send_bulk_course_email,
+                'emailed',
+                num_emails,
+                expected_succeeds,
+                failed=expected_fails,
+                retried_withmax=num_emails
+            )

-    def test_max_retry(self):
+    def test_max_retry_limit_causes_failure(self):
         # Test that celery can hit a maximum number of retries.
         num_students = 1
         self._create_students(num_students)
         # we also send email to the instructor:
         num_emails = num_students + 1
-        # This is an ugly hack: the failures that are reported by the EAGER version of retry
-        # are multiplied by the attempted number of retries (equals max plus one).
-        expected_fails = num_emails * (settings.BULK_EMAIL_MAX_RETRIES + 1)
+        expected_fails = num_emails
         expected_succeeds = 0
         with patch('bulk_email.tasks.get_connection', autospec=True) as get_conn:
             # always fail to connect, triggering repeated retries until limit is hit:
             get_conn.return_value.send_messages.side_effect = cycle([SMTPServerDisconnected(425, "Disconnecting")])
-            self._test_run_with_task(send_bulk_course_email, 'emailed', num_emails, expected_succeeds, failed=expected_fails)
+            self._test_run_with_task(
+                send_bulk_course_email,
+                'emailed',
+                num_emails,
+                expected_succeeds,
+                failed=expected_fails,
+                retried_withmax=(settings.BULK_EMAIL_MAX_RETRIES + 1)
+            )
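The test changes above hinge on Celery's eager mode: when tasks run eagerly, a retry executes synchronously and runs to completion before control returns to the code that invoked it, so retry results are recorded before the original call finishes. Below is a minimal sketch of eager execution, assuming a Celery 3.x-style configuration; the app and task names are illustrative, not from this codebase.

from celery import Celery

app = Celery('sketch')
app.conf.CELERY_ALWAYS_EAGER = True  # run tasks in-process, synchronously

@app.task
def add(x, y):
    return x + y

# With ALWAYS_EAGER set, .delay() does not enqueue anything: the task body
# (and any retry it triggers) has already run to completion by the time
# .delay() returns, which inverts the ordering the production code expects.
result = add.delay(2, 3)
print(result.get())  # -> 5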
@@ -87,6 +87,13 @@ def increment_subtask_status(subtask_result, succeeded=0, failed=0, skipped=0, r
     return new_result


+def _get_retry_count(subtask_result):
+    """Return the number of retries counted for the given subtask."""
+    retry_count = subtask_result.get('retried_nomax', 0)
+    retry_count += subtask_result.get('retried_withmax', 0)
+    return retry_count
+
+
 def update_instructor_task_for_subtasks(entry, action_name, total_num, subtask_id_list):
     """
     Store initial subtask information to InstructorTask object.
@@ -138,7 +145,6 @@ def update_instructor_task_for_subtasks(entry, action_name, total_num, subtask_i
         'total': num_subtasks,
         'succeeded': 0,
         'failed': 0,
-        'retried': 0,
         'status': subtask_status
     }
     entry.subtasks = json.dumps(subtask_dict)
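For reference, after this initialization `entry.subtasks` holds JSON of roughly the shape sketched below: the subtask-level 'retried' counter is gone, and retries are instead tracked per subtask as `retried_nomax`/`retried_withmax`. The task ids, the initial QUEUING state, and the per-subtask fields shown are illustrative, inferred from the assertions in the tests above.

# Illustrative shape of entry.subtasks after initialization (ids invented):
{
    'total': 2,
    'succeeded': 0,
    'failed': 0,  # the old top-level 'retried' counter no longer appears
    'status': {
        'subtask-id-1': {'task_id': 'subtask-id-1', 'state': 'QUEUING',
                         'attempted': 0, 'succeeded': 0, 'failed': 0, 'skipped': 0,
                         'retried_nomax': 0, 'retried_withmax': 0},
        'subtask-id-2': {'task_id': 'subtask-id-2', 'state': 'QUEUING',
                         'attempted': 0, 'succeeded': 0, 'failed': 0, 'skipped': 0,
                         'retried_nomax': 0, 'retried_withmax': 0},
    },
}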
@@ -190,18 +196,36 @@ def update_subtask_status(entry_id, current_task_id, new_subtask_status):
             TASK_LOG.warning(msg)
             raise ValueError(msg)

+        # Check for race condition where a subtask which has been retried
+        # has the retry already write its results here before the code
+        # that was invoking the retry has had a chance to update this status.
+        # While we think this is highly unlikely in production code, it is
+        # the norm in "eager" mode (used by tests) where the retry is called
+        # and run to completion before control is returned to the code that
+        # invoked the retry.
+        current_subtask_status = subtask_status_info[current_task_id]
+        current_retry_count = _get_retry_count(current_subtask_status)
+        new_retry_count = _get_retry_count(new_subtask_status)
+        if current_retry_count > new_retry_count:
+            TASK_LOG.warning("Task id %s: Retry %s has already updated InstructorTask -- skipping update for retry %s.",
+                             current_task_id, current_retry_count, new_retry_count)
+            transaction.rollback()
+            return
+        elif new_retry_count > 0:
+            TASK_LOG.debug("Task id %s: previous retry %s is not newer -- applying update for retry %s.",
+                           current_task_id, current_retry_count, new_retry_count)
+
         # Update status unless it has already been set. This can happen
         # when a task is retried and running in eager mode -- the retries
         # will be updating before the original call, and we don't want their
         # ultimate status to be clobbered by the "earlier" updates. This
         # should not be a problem in normal (non-eager) processing.
-        current_subtask_status = subtask_status_info[current_task_id]
         current_state = current_subtask_status['state']
         new_state = new_subtask_status['state']
-        if new_state != RETRY or current_state == QUEUING or current_state in READY_STATES:
+        if new_state != RETRY or current_state not in READY_STATES:
             subtask_status_info[current_task_id] = new_subtask_status

-        # Update the parent task progress
+        # Update the parent task progress.
         # Set the estimate of duration, but only if it
         # increases. Clock skew between time() returned by different machines
         # may result in non-monotonic values for duration.
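The race guard added in this hunk reduces to an ordering rule: an update carrying fewer total retries than what is already recorded describes an older attempt and must be skipped. A condensed standalone sketch of that rule follows; the helper names are hypothetical, mirroring the logic in the diff.

def _get_retry_count(subtask_status):
    """Total retries recorded in a subtask status dict (copy of the helper above)."""
    return (subtask_status.get('retried_nomax', 0) +
            subtask_status.get('retried_withmax', 0))


def is_update_stale(current_status, new_status):
    """True when the incoming update describes an older attempt than what is stored.

    In eager mode a retry runs to completion (and records its status) before
    the invoking call resumes, so the invoking call's update arrives last but
    carries a lower retry count and must not clobber the stored status.
    """
    return _get_retry_count(new_status) < _get_retry_count(current_status)


# The original call (retry count 0) reports after its retry (count 1) already did:
assert is_update_stale({'retried_withmax': 1}, {'retried_withmax': 0})
assert not is_update_stale({'retried_withmax': 1}, {'retried_withmax': 1})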
@@ -224,9 +248,7 @@ def update_subtask_status(entry_id, current_task_id, new_subtask_status):
         # entire new_subtask_status dict.
         if new_state == SUCCESS:
             subtask_dict['succeeded'] += 1
-        elif new_state == RETRY:
-            subtask_dict['retried'] += 1
-        else:
+        elif new_state in READY_STATES:
             subtask_dict['failed'] += 1
         num_remaining = subtask_dict['total'] - subtask_dict['succeeded'] - subtask_dict['failed']
@@ -246,6 +268,7 @@ def update_subtask_status(entry_id, current_task_id, new_subtask_status):
     except Exception:
         TASK_LOG.exception("Unexpected error while updating InstructorTask.")
         transaction.rollback()
+        raise
     else:
         TASK_LOG.debug("about to commit....")
         transaction.commit()
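The final hunk completes a conventional manual-transaction pattern: on any failure, roll back and re-raise so the caller sees the error; otherwise commit. A minimal sketch of the same shape, assuming the pre-Django-1.6 manual transaction management this code uses; the wrapper name is hypothetical.

from django.db import transaction

def guarded_update(do_update):
    """Hypothetical wrapper showing the commit/rollback shape used above."""
    try:
        do_update()  # e.g. mutate and save the InstructorTask entry
    except Exception:
        transaction.rollback()
        raise  # the newly added `raise` propagates the failure to the caller
    else:
        transaction.commit()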