Commit 615b560d by Will Daly

Add API calls for checking workflow completeness

Skip grading/training tasks if the workflow is already complete.
parent bd5c184f
......@@ -291,3 +291,65 @@ def create_classifiers(training_workflow_uuid, classifier_set):
).format(uuid=training_workflow_uuid, ex=ex)
logger.exception(msg)
raise AITrainingInternalError(msg)
def is_training_workflow_complete(workflow_uuid):
"""
Check whether the training workflow is complete.
Args:
workflow_uuid (str): The UUID of the training workflow
Returns:
bool
Raises:
AITrainingRequestError
AITrainingInternalError
"""
try:
return AITrainingWorkflow.is_workflow_complete(workflow_uuid)
except AITrainingWorkflow.DoesNotExist:
msg = (
u"Could not retrieve training workflow "
u"with uuid {uuid} to check whether it's complete."
).format(uuid=workflow_uuid)
raise AITrainingRequestError(msg)
except DatabaseError:
msg = (
u"An unexpected error occurred while checking "
u"the training workflow with uuid {uuid} for completeness"
).format(uuid=workflow_uuid)
raise AITrainingInternalError(msg)
def is_grading_workflow_complete(workflow_uuid):
"""
Check whether the grading workflow is complete.
Args:
workflow_uuid (str): The UUID of the grading workflow
Returns:
bool
Raises:
AIGradingRequestError
AIGradingInternalError
"""
try:
return AIGradingWorkflow.is_workflow_complete(workflow_uuid)
except AIGradingWorkflow.DoesNotExist:
msg = (
u"Could not retrieve grading workflow "
u"with uuid {uuid} to check whether it's complete."
).format(uuid=workflow_uuid)
raise AIGradingRequestError(msg)
except DatabaseError:
msg = (
u"An unexpected error occurred while checking "
u"the grading workflow with uuid {uuid} for completeness"
).format(uuid=workflow_uuid)
raise AIGradingInternalError(msg)
......@@ -481,6 +481,25 @@ class AIWorkflow(models.Model):
workflow = cls.objects.get(uuid=workflow_uuid)
yield workflow
@classmethod
def is_workflow_complete(cls, workflow_uuid):
"""
Check whether the workflow with a given UUID has been marked complete.
Args:
workflow_uuid (str): The UUID of the workflow to check.
Returns:
bool
Raises:
DatabaseError
cls.DoesNotExist
"""
workflow = cls.objects.get(uuid=workflow_uuid)
return workflow.is_complete
def _log_start_workflow(self):
"""
A logging operation called at the beginning of an AI Workflows life.
......
......@@ -188,6 +188,22 @@ class AIWorkerTrainingTest(CacheResetTest):
with self.assertRaises(AITrainingInternalError):
ai_worker_api.create_classifiers(workflow.uuid, CLASSIFIERS)
def test_is_workflow_complete(self):
self.assertFalse(ai_worker_api.is_training_workflow_complete(self.workflow_uuid))
workflow = AITrainingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.mark_complete_and_save()
self.assertTrue(ai_worker_api.is_training_workflow_complete(self.workflow_uuid))
def test_is_workflow_complete_no_such_workflow(self):
with self.assertRaises(AITrainingRequestError):
ai_worker_api.is_training_workflow_complete('no such workflow')
@mock.patch.object(AITrainingWorkflow.objects, 'get')
def test_is_workflow_complete_database_error(self, mock_call):
mock_call.side_effect = DatabaseError("Oh no!")
with self.assertRaises(AITrainingInternalError):
ai_worker_api.is_training_workflow_complete(self.workflow_uuid)
class AIWorkerGradingTest(CacheResetTest):
"""
......@@ -301,3 +317,19 @@ class AIWorkerGradingTest(CacheResetTest):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
def test_is_workflow_complete(self):
self.assertFalse(ai_worker_api.is_grading_workflow_complete(self.workflow_uuid))
workflow = AIGradingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.mark_complete_and_save()
self.assertTrue(ai_worker_api.is_grading_workflow_complete(self.workflow_uuid))
def test_is_workflow_complete_no_such_workflow(self):
with self.assertRaises(AIGradingRequestError):
ai_worker_api.is_grading_workflow_complete('no such workflow')
@mock.patch.object(AIGradingWorkflow.objects, 'get')
def test_is_workflow_complete_database_error(self, mock_call):
mock_call.side_effect = DatabaseError("Oh no!")
with self.assertRaises(AIGradingInternalError):
ai_worker_api.is_grading_workflow_complete(self.workflow_uuid)
......@@ -18,7 +18,9 @@ from openassessment.assessment.worker.algorithm import (
from openassessment.assessment.serializers import (
deserialize_training_examples, rubric_from_dict
)
from openassessment.assessment.errors import AITrainingRequestError, AIGradingInternalError
from openassessment.assessment.errors import (
AITrainingRequestError, AIGradingInternalError, AIGradingRequestError
)
from openassessment.assessment.test.constants import (
EXAMPLES, RUBRIC, STUDENT_ITEM, ANSWER
)
......@@ -135,6 +137,26 @@ class AITrainingTaskTest(CeleryTaskTest):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_skip_completed_workflow(self):
# Mark the grading workflow as complete
workflow = AITrainingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.mark_complete_and_save()
# The training task should short-circuit immediately, skipping calls
# to get parameters for the task.
actual_call = ai_worker_api.get_training_task_params
patched = 'openassessment.assessment.worker.grading.ai_worker_api.get_training_task_params'
with mock.patch(patched) as mock_call:
mock_call.side_effect = actual_call
train_classifiers(self.workflow_uuid)
self.assertFalse(mock_call.called)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_check_complete_error(self):
with self.assert_retry(train_classifiers, AITrainingRequestError):
train_classifiers("no such workflow uuid")
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_unable_to_load_algorithm_class(self):
# The algorithm is defined in the settings, but the class does not exist.
self._set_algorithm_id(UNDEFINED_CLASS_ALGORITHM_ID)
......@@ -266,6 +288,26 @@ class AIGradingTaskTest(CeleryTaskTest):
workflow.classifier_set = classifier_set
workflow.save()
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_skip_completed_workflow(self):
# Mark the grading workflow as complete
workflow = AIGradingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.mark_complete_and_save()
# The grading task should short-circuit immediately, skipping calls
# to get parameters for the task.
actual_call = ai_worker_api.get_grading_task_params
patched = 'openassessment.assessment.worker.grading.ai_worker_api.get_grading_task_params'
with mock.patch(patched) as mock_call:
mock_call.side_effect = actual_call
grade_essay(self.workflow_uuid)
self.assertFalse(mock_call.called)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_check_complete_error(self):
with self.assert_retry(grade_essay, AIGradingRequestError):
grade_essay("no such workflow uuid")
@mock.patch('openassessment.assessment.api.ai_worker.create_assessment')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_algorithm_gives_invalid_score(self, mock_create_assessment):
......
......@@ -48,6 +48,21 @@ def grade_essay(workflow_uuid):
AIAlgorithmError: An error occurred while retrieving or using an AI algorithm.
"""
# Short-circuit if the workflow is already marked complete
# This is an optimization, but grading tasks could still
# execute multiple times depending on when they get picked
# up by workers and marked complete.
try:
if ai_worker_api.is_grading_workflow_complete(workflow_uuid):
return
except AIError:
msg = (
u"An unexpected error occurred while checking the "
u"completion of grading workflow with UUID {uuid}"
).format(uuid=workflow_uuid)
logger.exception(msg)
raise grade_essay.retry()
# Retrieve the task parameters
try:
params = ai_worker_api.get_grading_task_params(workflow_uuid)
......
......@@ -65,6 +65,21 @@ def train_classifiers(workflow_uuid):
InvalidExample: The training examples provided by the AI API were not valid.
"""
# Short-circuit if the workflow is already marked complete
# This is an optimization, but training tasks could still
# execute multiple times depending on when they get picked
# up by workers and marked complete.
try:
if ai_worker_api.is_training_workflow_complete(workflow_uuid):
return
except AIError:
msg = (
u"An unexpected error occurred while checking the "
u"completion of training workflow with UUID {uuid}"
).format(uuid=workflow_uuid)
logger.exception(msg)
raise train_classifiers.retry()
# Retrieve task parameters
try:
params = ai_worker_api.get_training_task_params(workflow_uuid)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment