Commit bedbab88 by Will Daly

Add tests for AI grading API calls

parent bd087b5f
@@ -68,7 +68,7 @@ def get_grading_task_params(grading_workflow_uuid):
'classifier_set': classifier_set.classifiers_dict,
'algorithm_id': workflow.algorithm_id,
}
except (ValueError, IOError, DatabaseError) as ex:
except Exception as ex:
msg = (
u"An unexpected error occurred while retrieving "
u"classifiers for the grading workflow with UUID {uuid}: {ex}"
@@ -253,10 +253,3 @@ def create_classifiers(training_workflow_uuid, classifier_set):
).format(uuid=training_workflow_uuid, ex=ex)
logger.exception(msg)
raise AITrainingInternalError(msg)
except DatabaseError:
msg = (
u"An unexpected error occurred while creating the classifier set "
u"for the AI training workflow with UUID {}"
).format(training_workflow_uuid)
logger.exception(msg)
raise AITrainingInternalError(msg)
@@ -9,10 +9,15 @@ from django.test.utils import override_settings
from openassessment.test_utils import CacheResetTest
from submissions import api as sub_api
from openassessment.assessment.api import ai as ai_api
from openassessment.assessment.models import AITrainingWorkflow, AIClassifierSet
from openassessment.assessment.models import (
AITrainingWorkflow, AIGradingWorkflow, AIClassifierSet, Assessment
)
from openassessment.assessment.worker.algorithm import AIAlgorithm
from openassessment.assessment.serializers import rubric_from_dict
from openassessment.assessment.errors import AITrainingRequestError, AITrainingInternalError
from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError,
AIGradingRequestError, AIGradingInternalError
)
from openassessment.assessment.test.constants import RUBRIC, EXAMPLES, STUDENT_ITEM, ANSWER
@@ -53,7 +58,7 @@ class StubAIAlgorithm(AIAlgorithm):
ALGORITHM_ID = "test-stub"
AI_ALGORITHMS = {
ALGORITHM_ID: '{module}.StubAIAlgorithm'.format(module=__name__)
ALGORITHM_ID: '{module}.StubAIAlgorithm'.format(module=__name__),
}
@@ -179,3 +184,56 @@ class AIGradingTest(CacheResetTest):
criterion_name = part['option']['criterion']['name']
expected_score = self.CLASSIFIER_SCORE_OVERRIDES[criterion_name]['score_override']
self.assertEqual(part['option']['points'], expected_score)
@mock.patch('openassessment.assessment.api.ai.grading_tasks.grade_essay')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_submit_no_classifiers_available(self, mock_task):
# Use a rubric that does not have classifiers available
new_rubric = copy.deepcopy(RUBRIC)
new_rubric['criteria'] = new_rubric['criteria'][1:]
# Submit the essay -- since there are no classifiers available,
# the workflow should be created, but no task should be scheduled.
workflow_uuid = ai_api.submit(self.submission_uuid, new_rubric, ALGORITHM_ID)
# Verify that the workflow was created with a null classifier set
workflow = AIGradingWorkflow.objects.get(uuid=workflow_uuid)
self.assertIs(workflow.classifier_set, None)
# Verify that there are no assessments
latest_assessment = ai_api.get_latest_assessment(self.submission_uuid)
self.assertIs(latest_assessment, None)
# Verify that the task was never scheduled
self.assertFalse(mock_task.apply_async.called)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_submit_submission_not_found(self):
with self.assertRaises(AIGradingRequestError):
ai_api.submit("no_such_submission", RUBRIC, ALGORITHM_ID)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_submit_invalid_rubric(self):
invalid_rubric = {'not_valid': True}
with self.assertRaises(AIGradingRequestError):
ai_api.submit(self.submission_uuid, invalid_rubric, ALGORITHM_ID)
@mock.patch.object(AIGradingWorkflow.objects, 'create')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_submit_database_error(self, mock_call):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_api.submit(self.submission_uuid, RUBRIC, ALGORITHM_ID)
@mock.patch('openassessment.assessment.api.ai.grading_tasks.grade_essay')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_grade_task_schedule_error(self, mock_task):
mock_task.apply_async.side_effect = IOError("Test error!")
with self.assertRaises(AIGradingInternalError):
ai_api.submit(self.submission_uuid, RUBRIC, ALGORITHM_ID)
@mock.patch.object(Assessment.objects, 'filter')
def test_get_latest_assessment_database_error(self, mock_call):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_api.get_latest_assessment(self.submission_uuid)
@@ -4,16 +4,43 @@ Tests for AI worker API calls.
"""
import copy
import datetime
from uuid import uuid4
import mock
from django.db import DatabaseError
from django.core.files.base import ContentFile
from submissions import api as sub_api
from openassessment.test_utils import CacheResetTest
from openassessment.assessment.api import ai_worker as ai_worker_api
from openassessment.assessment.models import AITrainingWorkflow, AIClassifier
from openassessment.assessment.serializers import deserialize_training_examples
from openassessment.assessment.models import (
AITrainingWorkflow, AIGradingWorkflow,
AIClassifier, AIClassifierSet, Assessment
)
from openassessment.assessment.serializers import (
rubric_from_dict, deserialize_training_examples
)
from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError
AITrainingRequestError, AITrainingInternalError,
AIGradingRequestError, AIGradingInternalError
)
from openassessment.assessment.test.constants import (
EXAMPLES, RUBRIC, STUDENT_ITEM, ANSWER
)
from openassessment.assessment.test.constants import EXAMPLES, RUBRIC
ALGORITHM_ID = "test-algorithm"
# Classifier data
# Since this is controlled by the AI algorithm implementation,
# we could put anything here as long as it's JSON-serializable.
CLASSIFIERS = {
u"vøȼȺƀᵾłȺɍɏ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u'Öḧ ḷëẗ ẗḧë ṡüṅ ḅëäẗ ḋöẅṅ üṗöṅ ṁÿ ḟäċë, ṡẗäṛṡ ẗö ḟïḷḷ ṁÿ ḋṛëäṁ"'
},
u"ﻭɼค๓๓คɼ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u"І ам а тѓаvэlэѓ оf ъотЂ тімэ аиↁ ѕрасэ, то ъэ шЂэѓэ І Ђаvэ ъээи"
}
}
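As the comment above notes, the API treats this classifier payload as opaque; all it requires is that the payload survive JSON serialization, which is exactly the property the serialize-error test below breaks with a datetime value. A quick, illustrative check of that requirement:

import json

# The payload only needs to round-trip through JSON; the API never inspects its contents.
assert json.loads(json.dumps(CLASSIFIERS)) == CLASSIFIERS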
class AIWorkerTrainingTest(CacheResetTest):
@@ -21,28 +48,12 @@ class AIWorkerTrainingTest(CacheResetTest):
Tests for the AI API calls a worker would make when
completing a training task.
"""
ALGORITHM_ID = "test-algorithm"
# Classifier data
# Since this is controlled by the AI algorithm implementation,
# we could put anything here as long as it's JSON-serializable.
CLASSIFIERS = {
u"vøȼȺƀᵾłȺɍɏ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u'Öḧ ḷëẗ ẗḧë ṡüṅ ḅëäẗ ḋöẅṅ üṗöṅ ṁÿ ḟäċë, ṡẗäṛṡ ẗö ḟïḷḷ ṁÿ ḋṛëäṁ"'
},
u"ﻭɼค๓๓คɼ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u"І ам а тѓаvэlэѓ оf ъотЂ тімэ аиↁ ѕрасэ, то ъэ шЂэѓэ І Ђаvэ ъээи"
}
}
def setUp(self):
"""
Create a training workflow in the database.
"""
examples = deserialize_training_examples(EXAMPLES, RUBRIC)
workflow = AITrainingWorkflow.start_workflow(examples, self.ALGORITHM_ID)
workflow = AITrainingWorkflow.start_workflow(examples, ALGORITHM_ID)
self.workflow_uuid = workflow.uuid
def test_get_training_task_params(self):
@@ -64,7 +75,7 @@ class AIWorkerTrainingTest(CacheResetTest):
},
]
self.assertItemsEqual(params['training_examples'], expected_examples)
self.assertItemsEqual(params['algorithm_id'], self.ALGORITHM_ID)
self.assertItemsEqual(params['algorithm_id'], ALGORITHM_ID)
def test_get_training_task_params_no_workflow(self):
with self.assertRaises(AITrainingRequestError):
@@ -77,7 +88,7 @@ class AIWorkerTrainingTest(CacheResetTest):
ai_worker_api.get_training_task_params(self.workflow_uuid)
def test_create_classifiers(self):
ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, CLASSIFIERS)
# Expect that the workflow was marked complete
workflow = AITrainingWorkflow.objects.get(uuid=self.workflow_uuid)
@@ -86,21 +97,21 @@ class AIWorkerTrainingTest(CacheResetTest):
# Expect that the classifier set was created with the correct data
self.assertIsNot(workflow.classifier_set, None)
saved_classifiers = workflow.classifier_set.classifiers_dict
self.assertItemsEqual(self.CLASSIFIERS, saved_classifiers)
self.assertItemsEqual(CLASSIFIERS, saved_classifiers)
def test_create_classifiers_no_workflow(self):
with self.assertRaises(AITrainingRequestError):
ai_worker_api.create_classifiers("invalid_uuid", self.CLASSIFIERS)
ai_worker_api.create_classifiers("invalid_uuid", CLASSIFIERS)
@mock.patch.object(AITrainingWorkflow.objects, 'get')
def test_create_classifiers_database_error(self, mock_get):
mock_get.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AITrainingInternalError):
ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, CLASSIFIERS)
def test_create_classifiers_serialize_error(self):
# Mutate the classifier data so it is NOT JSON-serializable
classifiers = copy.deepcopy(self.CLASSIFIERS)
classifiers = copy.deepcopy(CLASSIFIERS)
classifiers[u"vøȼȺƀᵾłȺɍɏ"] = datetime.datetime.now()
# Expect an error when we try to create the classifiers
@@ -109,7 +120,7 @@ class AIWorkerTrainingTest(CacheResetTest):
def test_create_classifiers_missing_criteria(self):
# Remove a criterion from the classifiers dict
classifiers = copy.deepcopy(self.CLASSIFIERS)
classifiers = copy.deepcopy(CLASSIFIERS)
del classifiers[u"vøȼȺƀᵾłȺɍɏ"]
# Expect an error when we try to create the classifiers
@@ -118,7 +129,7 @@ class AIWorkerTrainingTest(CacheResetTest):
def test_create_classifiers_unrecognized_criterion(self):
# Add an extra criterion to the classifiers dict
classifiers = copy.deepcopy(self.CLASSIFIERS)
classifiers = copy.deepcopy(CLASSIFIERS)
classifiers[u"extra_criterion"] = copy.deepcopy(classifiers[u"vøȼȺƀᵾłȺɍɏ"])
# Expect an error when we try to create the classifiers
@@ -130,14 +141,14 @@ class AIWorkerTrainingTest(CacheResetTest):
# Simulate an error occurring when uploading the trained classifier
mock_data.save.side_effect = IOError("OH NO!!!")
with self.assertRaises(AITrainingInternalError):
ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, CLASSIFIERS)
def test_create_classifiers_twice(self):
# Simulate repeated task execution for the same workflow
# Since these are executed sequentially, the second call should
# have no effect.
ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, CLASSIFIERS)
ai_worker_api.create_classifiers(self.workflow_uuid, CLASSIFIERS)
# Expect that the workflow was marked complete
workflow = AITrainingWorkflow.objects.get(uuid=self.workflow_uuid)
@@ -146,12 +157,113 @@ class AIWorkerTrainingTest(CacheResetTest):
# Expect that the classifier set was created with the correct data
self.assertIsNot(workflow.classifier_set, None)
saved_classifiers = workflow.classifier_set.classifiers_dict
self.assertItemsEqual(self.CLASSIFIERS, saved_classifiers)
self.assertItemsEqual(CLASSIFIERS, saved_classifiers)
def test_create_classifiers_no_training_examples(self):
# Create a workflow with no training examples
workflow = AITrainingWorkflow.objects.create(algorithm_id=self.ALGORITHM_ID)
workflow = AITrainingWorkflow.objects.create(algorithm_id=ALGORITHM_ID)
# Expect an error when we try to create classifiers
with self.assertRaises(AITrainingInternalError):
ai_worker_api.create_classifiers(workflow.uuid, self.CLASSIFIERS)
ai_worker_api.create_classifiers(workflow.uuid, CLASSIFIERS)
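Compressed into one place, the worker-side training flow that the cases above exercise step by step looks roughly like the sketch below. Only the two ai_worker_api calls are taken from the tests; workflow_uuid stands in for the UUID returned by AITrainingWorkflow.start_workflow (as in setUp), and the training step itself is algorithm-specific and omitted:

from openassessment.assessment.api import ai_worker as ai_worker_api

# Fetch the training examples and algorithm ID for a scheduled training workflow.
params = ai_worker_api.get_training_task_params(workflow_uuid)

# ... train one classifier per rubric criterion from params['training_examples'] ...
classifiers = {criterion: {} for criterion in CLASSIFIERS}  # placeholder payloads

# Store the classifiers; this marks the training workflow complete.
ai_worker_api.create_classifiers(workflow_uuid, classifiers)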
class AIWorkerGradingTest(CacheResetTest):
"""
Tests for the AI API calls a worker would make when
completing a grading task.
"""
SCORES = {
u"vøȼȺƀᵾłȺɍɏ": 1,
u"ﻭɼค๓๓คɼ": 0
}
def setUp(self):
"""
Create a grading workflow in the database.
"""
# Create a submission
submission = sub_api.create_submission(STUDENT_ITEM, ANSWER)
self.submission_uuid = submission['uuid']
# Create a workflow for the submission
workflow = AIGradingWorkflow.start_workflow(self.submission_uuid, RUBRIC, ALGORITHM_ID)
self.workflow_uuid = workflow.uuid
# Associate the workflow with classifiers
rubric = rubric_from_dict(RUBRIC)
classifier_set = AIClassifierSet.create_classifier_set(
CLASSIFIERS, rubric, ALGORITHM_ID
)
workflow.classifier_set = classifier_set
workflow.save()
def test_get_grading_task_params(self):
params = ai_worker_api.get_grading_task_params(self.workflow_uuid)
expected_params = {
'essay_text': ANSWER,
'classifier_set': CLASSIFIERS,
'algorithm_id': ALGORITHM_ID
}
self.assertItemsEqual(params, expected_params)
def test_get_grading_task_params_no_workflow(self):
with self.assertRaises(AIGradingRequestError):
ai_worker_api.get_grading_task_params("invalid_uuid")
def test_get_grading_task_params_no_classifiers(self):
# Remove the classifiers from the workflow
workflow = AIGradingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.classifier_set = None
workflow.save()
# Should get an error when retrieving task params
with self.assertRaises(AIGradingInternalError):
ai_worker_api.get_grading_task_params(self.workflow_uuid)
@mock.patch.object(AIGradingWorkflow.objects, 'get')
def test_get_grading_task_params_database_error(self, mock_call):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_worker_api.get_grading_task_params(self.submission_uuid)
def test_invalid_classifier_data(self):
# Modify the classifier data so it is not valid JSON
invalid_json = "{"
for classifier in AIClassifier.objects.all():
classifier.classifier_data.save(uuid4().hex, ContentFile(invalid_json))
# Should get an error when retrieving task params
with self.assertRaises(AIGradingInternalError):
ai_worker_api.get_grading_task_params(self.workflow_uuid)
def test_create_assessment(self):
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
assessment = Assessment.objects.get(submission_uuid=self.submission_uuid)
self.assertEqual(assessment.points_earned, 1)
def test_create_assessment_no_workflow(self):
with self.assertRaises(AIGradingRequestError):
ai_worker_api.create_assessment("invalid_uuid", self.SCORES)
def test_create_assessment_workflow_already_complete(self):
# Try to create assessments for the same workflow multiple times
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
# Expect that only one assessment is created for the submission
num_assessments = Assessment.objects.filter(submission_uuid=self.submission_uuid).count()
self.assertEqual(num_assessments, 1)
@mock.patch.object(AIGradingWorkflow.objects, 'get')
def test_create_assessment_database_error_retrieving_workflow(self, mock_call):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
@mock.patch.object(Assessment.objects, 'create')
def test_create_assessment_database_error_complete_workflow(self, mock_call):
mock_call.side_effect = DatabaseError("KABOOM!")
with self.assertRaises(AIGradingInternalError):
ai_worker_api.create_assessment(self.workflow_uuid, self.SCORES)
@@ -5,16 +5,22 @@ Tests for AI worker tasks.
from contextlib import contextmanager
import mock
from django.test.utils import override_settings
from submissions import api as sub_api
from openassessment.test_utils import CacheResetTest
from openassessment.assessment.worker.training import train_classifiers, InvalidExample
from openassessment.assessment.worker.grading import grade_essay
from openassessment.assessment.api import ai_worker as ai_worker_api
from openassessment.assessment.models import AITrainingWorkflow
from openassessment.assessment.models import AITrainingWorkflow, AIGradingWorkflow, AIClassifierSet
from openassessment.assessment.worker.algorithm import (
AIAlgorithm, UnknownAlgorithm, AlgorithmLoadError, TrainingError
AIAlgorithm, UnknownAlgorithm, AlgorithmLoadError, TrainingError, ScoreError
)
from openassessment.assessment.serializers import (
deserialize_training_examples, rubric_from_dict
)
from openassessment.assessment.errors import AITrainingRequestError, AIGradingInternalError
from openassessment.assessment.test.constants import (
EXAMPLES, RUBRIC, STUDENT_ITEM, ANSWER
)
from openassessment.assessment.serializers import deserialize_training_examples
from openassessment.assessment.errors import AITrainingRequestError
from openassessment.assessment.test.constants import EXAMPLES, RUBRIC
class StubAIAlgorithm(AIAlgorithm):
@@ -25,7 +31,7 @@ class StubAIAlgorithm(AIAlgorithm):
return {}
def score(self, text, classifier):
raise NotImplementedError
return 0
class ErrorStubAIAlgorithm(AIAlgorithm):
@@ -36,58 +42,87 @@ class ErrorStubAIAlgorithm(AIAlgorithm):
raise TrainingError("Test error!")
def score(self, text, classifier):
raise NotImplementedError
raise ScoreError("Test error!")
ALGORITHM_ID = u"test-stub"
ERROR_STUB_ALGORITHM_ID = u"error-stub"
UNDEFINED_CLASS_ALGORITHM_ID = u"undefined_class"
UNDEFINED_MODULE_ALGORITHM_ID = u"undefined_module"
AI_ALGORITHMS = {
ALGORITHM_ID: '{module}.StubAIAlgorithm'.format(module=__name__),
ERROR_STUB_ALGORITHM_ID: '{module}.ErrorStubAIAlgorithm'.format(module=__name__),
UNDEFINED_CLASS_ALGORITHM_ID: '{module}.NotDefinedAIAlgorithm'.format(module=__name__),
UNDEFINED_MODULE_ALGORITHM_ID: 'openassessment.not.valid.NotDefinedAIAlgorithm'
}
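AI_ALGORITHMS mirrors the ORA2_AI_ALGORITHMS Django setting: each algorithm ID maps to a dotted import path that the worker resolves at run time, which is why the two "undefined" IDs feed the AlgorithmLoadError tests further down. A rough sketch of that resolution, assuming an importlib-based loader (the worker's actual loader may differ):

import importlib

def load_algorithm(algorithm_id, algorithms=AI_ALGORITHMS):
    # Split "package.module.ClassName" into module path and class name,
    # import the module, and instantiate the algorithm class.
    module_path, class_name = algorithms[algorithm_id].rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)()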
class AITrainingTaskTest(CacheResetTest):
class CeleryTaskTest(CacheResetTest):
"""
Tests for the training task executed asynchronously by Celery workers.
Test case for Celery tasks.
"""
@contextmanager
def assert_retry(self, task, final_exception):
"""
Context manager that asserts that the training task was retried.
ALGORITHM_ID = u"test-stub"
ERROR_STUB_ALGORITHM_ID = u"error-stub"
UNDEFINED_CLASS_ALGORITHM_ID = u"undefined_class"
UNDEFINED_MODULE_ALGORITHM_ID = u"undefined_module"
AI_ALGORITHMS = {
ALGORITHM_ID: '{module}.StubAIAlgorithm'.format(module=__name__),
ERROR_STUB_ALGORITHM_ID: '{module}.ErrorStubAIAlgorithm'.format(module=__name__),
UNDEFINED_CLASS_ALGORITHM_ID: '{module}.NotDefinedAIAlgorithm'.format(module=__name__),
UNDEFINED_MODULE_ALGORITHM_ID: 'openassessment.not.valid.NotDefinedAIAlgorithm'
}
Args:
task (celery.app.task.Task): The Celery task object.
final_exception (Exception): The error thrown after retrying.
Raises:
AssertionError
"""
original_retry = task.retry
task.retry = mock.MagicMock()
task.retry.side_effect = lambda: original_retry(task)
try:
with self.assertRaises(final_exception):
yield
task.retry.assert_called_once()
finally:
task.retry = original_retry
class AITrainingTaskTest(CeleryTaskTest):
"""
Tests for the training task executed asynchronously by Celery workers.
"""
def setUp(self):
"""
Create a training workflow in the database.
"""
examples = deserialize_training_examples(EXAMPLES, RUBRIC)
workflow = AITrainingWorkflow.start_workflow(examples, self.ALGORITHM_ID)
workflow = AITrainingWorkflow.start_workflow(examples, ALGORITHM_ID)
self.workflow_uuid = workflow.uuid
def test_unknown_algorithm(self):
# Since we haven't overridden settings to configure the algorithms,
# the worker will not recognize the workflow's algorithm ID.
with self._assert_retry(train_classifiers, UnknownAlgorithm):
with self.assert_retry(train_classifiers, UnknownAlgorithm):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_unable_to_load_algorithm_class(self):
# The algorithm is defined in the settings, but the class does not exist.
self._set_algorithm_id(self.UNDEFINED_CLASS_ALGORITHM_ID)
with self._assert_retry(train_classifiers, AlgorithmLoadError):
self._set_algorithm_id(UNDEFINED_CLASS_ALGORITHM_ID)
with self.assert_retry(train_classifiers, AlgorithmLoadError):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_unable_to_find_algorithm_module(self):
# The algorithm is defined in the settings, but the module can't be loaded
self._set_algorithm_id(self.UNDEFINED_MODULE_ALGORITHM_ID)
with self._assert_retry(train_classifiers, AlgorithmLoadError):
self._set_algorithm_id(UNDEFINED_MODULE_ALGORITHM_ID)
with self.assert_retry(train_classifiers, AlgorithmLoadError):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
@mock.patch('openassessment.assessment.worker.training.ai_worker_api.get_training_task_params')
def test_get_training_task_params_api_error(self, mock_call):
mock_call.side_effect = AITrainingRequestError("Test error!")
with self._assert_retry(train_classifiers, AITrainingRequestError):
with self.assert_retry(train_classifiers, AITrainingRequestError):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
@@ -111,15 +146,15 @@ class AITrainingTaskTest(CacheResetTest):
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_training_algorithm_error(self):
# Use a stub algorithm implementation that raises an exception during training
self._set_algorithm_id(self.ERROR_STUB_ALGORITHM_ID)
with self._assert_retry(train_classifiers, TrainingError):
self._set_algorithm_id(ERROR_STUB_ALGORITHM_ID)
with self.assert_retry(train_classifiers, TrainingError):
train_classifiers(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
@mock.patch('openassessment.assessment.worker.training.ai_worker_api.create_classifiers')
def test_create_classifiers_api_error(self, mock_call):
mock_call.side_effect = AITrainingRequestError("Test error!")
with self._assert_retry(train_classifiers, AITrainingRequestError):
with self.assert_retry(train_classifiers, AITrainingRequestError):
train_classifiers(self.workflow_uuid)
def _set_algorithm_id(self, algorithm_id):
@@ -158,28 +193,85 @@ class AITrainingTaskTest(CacheResetTest):
call_signature = 'openassessment.assessment.worker.training.ai_worker_api.get_training_task_params'
with mock.patch(call_signature) as mock_call:
mock_call.return_value = params
with self._assert_retry(train_classifiers, InvalidExample):
with self.assert_retry(train_classifiers, InvalidExample):
train_classifiers(self.workflow_uuid)
@contextmanager
def _assert_retry(self, task, final_exception):
class AIGradingTaskTest(CeleryTaskTest):
"""
Tests for the grading task executed asynchronously by Celery workers.
"""
# Classifier data
# Since this is controlled by the AI algorithm implementation,
# we could put anything here as long as it's JSON-serializable.
CLASSIFIERS = {
u"vøȼȺƀᵾłȺɍɏ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u'Öḧ ḷëẗ ẗḧë ṡüṅ ḅëäẗ ḋöẅṅ üṗöṅ ṁÿ ḟäċë, ṡẗäṛṡ ẗö ḟïḷḷ ṁÿ ḋṛëäṁ"'
},
u"ﻭɼค๓๓คɼ": {
'name': u'𝒕𝒆𝒔𝒕 𝒄𝒍𝒂𝒔𝒔𝒊𝒇𝒊𝒆𝒓',
'data': u"І ам а тѓаvэlэѓ оf ъотЂ тімэ аиↁ ѕрасэ, то ъэ шЂэѓэ І Ђаvэ ъээи"
}
}
def setUp(self):
"""
Context manager that asserts that the training task was retried.
Create a submission and grading workflow.
"""
# Create a submission
submission = sub_api.create_submission(STUDENT_ITEM, ANSWER)
self.submission_uuid = submission['uuid']
# Create a workflow for the submission
workflow = AIGradingWorkflow.start_workflow(self.submission_uuid, RUBRIC, ALGORITHM_ID)
self.workflow_uuid = workflow.uuid
# Associate the workflow with classifiers
rubric = rubric_from_dict(RUBRIC)
classifier_set = AIClassifierSet.create_classifier_set(
self.CLASSIFIERS, rubric, ALGORITHM_ID
)
workflow.classifier_set = classifier_set
workflow.save()
@mock.patch('openassessment.assessment.worker.grading.ai_worker_api.get_grading_task_params')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_retrieve_params_error(self, mock_call):
mock_call.side_effect = AIGradingInternalError("Test error")
with self.assert_retry(grade_essay, AIGradingInternalError):
grade_essay(self.workflow_uuid)
def test_unknown_algorithm_id_error(self):
# Since we're not overriding settings, the algorithm ID won't be recognized
with self.assert_retry(grade_essay, UnknownAlgorithm):
grade_essay(self.workflow_uuid)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_algorithm_score_error(self):
self._set_algorithm_id(ERROR_STUB_ALGORITHM_ID)
with self.assert_retry(grade_essay, ScoreError):
grade_essay(self.workflow_uuid)
@mock.patch('openassessment.assessment.worker.grading.ai_worker_api.create_assessment')
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
def test_create_assessment_error(self, mock_call):
mock_call.side_effect = AIGradingInternalError
with self.assert_retry(grade_essay, AIGradingInternalError):
grade_essay(self.workflow_uuid)
def _set_algorithm_id(self, algorithm_id):
"""
Override the default algorithm ID for the grading workflow.
Args:
task (celery.app.task.Task): The Celery task object.
final_exception (Exception): The error thrown after retrying.
algorithm_id (unicode): The new algorithm ID
Raises:
AssertionError
Returns:
None
"""
original_retry = task.retry
task.retry = mock.MagicMock()
task.retry.side_effect = lambda: original_retry(task)
try:
with self.assertRaises(final_exception):
yield
task.retry.assert_called_once()
finally:
task.retry = original_retry
workflow = AIGradingWorkflow.objects.get(uuid=self.workflow_uuid)
workflow.algorithm_id = algorithm_id
workflow.save()