Commit 74a9d63e by Will Daly

Implement AI grading task, database models, and API calls.

parent ec584292
@@ -2,22 +2,29 @@
Public interface for AI training and grading, used by students/course authors.
"""
import logging

from django.db import DatabaseError

from submissions import api as sub_api
from openassessment.assessment.serializers import (
    deserialize_training_examples, InvalidTrainingExample, InvalidRubric,
    full_assessment_dict
)
from openassessment.assessment.errors import (
    AITrainingRequestError, AITrainingInternalError,
    AIGradingRequestError, AIGradingInternalError
)
from openassessment.assessment.models import (
    AITrainingWorkflow, InvalidOptionSelection, NoTrainingExamples,
    Assessment, AIGradingWorkflow,
    AIClassifierSet, AI_ASSESSMENT_TYPE
)
from openassessment.assessment.worker import training as training_tasks
from openassessment.assessment.worker import grading as grading_tasks


logger = logging.getLogger(__name__)
def submit(submission_uuid, rubric, algorithm_id):
    """
    Submit a response for AI assessment.
    This will:
@@ -27,6 +34,7 @@ def submit(submission_uuid, rubric):
    Args:
        submission_uuid (str): The UUID of the submission to assess.
        rubric (dict): Serialized rubric model.
        algorithm_id (unicode): Use only classifiers trained with the specified algorithm.

    Returns:
        grading_workflow_uuid (str): The UUID of the grading workflow.
@@ -39,7 +47,50 @@ def submit(submission_uuid, rubric):
        AIGradingInternalError
    """
    try:
        workflow = AIGradingWorkflow.start_workflow(submission_uuid, rubric, algorithm_id)
    except (sub_api.SubmissionNotFoundError, sub_api.SubmissionRequestError) as ex:
        msg = (
            u"An error occurred while retrieving the "
            u"submission with UUID {uuid}: {ex}"
        ).format(uuid=submission_uuid, ex=ex)
        raise AIGradingRequestError(msg)
    except InvalidRubric as ex:
        msg = (
            u"An error occurred while parsing the serialized "
            u"rubric {rubric}: {ex}"
        ).format(rubric=rubric, ex=ex)
        raise AIGradingRequestError(msg)
    except (sub_api.SubmissionInternalError, DatabaseError) as ex:
        msg = (
            u"An unexpected error occurred while submitting an "
            u"essay for AI grading: {ex}"
        ).format(ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)

    try:
        classifier_set_candidates = AIClassifierSet.objects.filter(
            rubric=workflow.rubric, algorithm_id=algorithm_id
        ).order_by('-created_at')[:1]

        # If we find classifiers for this rubric/algorithm,
        # then associate the classifiers with the workflow
        # and schedule a grading task.
        # Otherwise, the task will need to be scheduled later,
        # once the classifiers have been trained.
        if len(classifier_set_candidates) > 0:
            workflow.classifier_set = classifier_set_candidates[0]
            workflow.save()
            grading_tasks.grade_essay.apply_async(args=[workflow.uuid])

        return workflow.uuid
    except Exception as ex:
        msg = (
            u"An unexpected error occurred while scheduling the "
            u"AI grading task for the submission with UUID {uuid}: {ex}"
        ).format(uuid=submission_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)
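
A minimal usage sketch of this API (not part of the commit; `submission_uuid` and `rubric_dict` are placeholders, and u"ease" is an illustrative algorithm ID):

    from openassessment.assessment.api import ai as ai_api
    from openassessment.assessment.errors import AIGradingRequestError, AIGradingInternalError

    submission_uuid = u"example-submission-uuid"  # placeholder
    rubric_dict = {}                              # placeholder: a serialized rubric model

    try:
        # Schedules a grading task immediately if trained classifiers
        # already exist for this rubric/algorithm; otherwise the task
        # must be scheduled once training completes.
        workflow_uuid = ai_api.submit(submission_uuid, rubric_dict, u"ease")
    except AIGradingRequestError:
        pass  # Invalid submission UUID or rubric.
    except AIGradingInternalError:
        pass  # Database error or task-scheduling failure.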
def get_latest_assessment(submission_uuid):
@@ -51,13 +102,29 @@ def get_latest_assessment(submission_uuid):
    Returns:
        dict: The serialized assessment model,
        or None if no assessments are available.

    Raises:
        AIGradingRequestError
        AIGradingInternalError
    """
    try:
        # Order by descending "scored_at" so we return the most recent assessment.
        assessments = Assessment.objects.filter(
            submission_uuid=submission_uuid,
            score_type=AI_ASSESSMENT_TYPE,
        ).order_by('-scored_at')[:1]
    except DatabaseError as ex:
        msg = (
            u"An error occurred while retrieving AI graded assessments "
            u"for the submission with UUID {uuid}: {ex}"
        ).format(uuid=submission_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)

    if len(assessments) > 0:
        return full_assessment_dict(assessments[0])
    else:
        return None
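
A short read-back sketch, using the `parts` structure exercised in the grading test below (illustrative only):

    assessment = ai_api.get_latest_assessment(submission_uuid)
    if assessment is not None:
        for part in assessment['parts']:
            criterion_name = part['option']['criterion']['name']
            points = part['option']['points']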

def train_classifiers(rubric_dict, examples, algorithm_id):
...
@@ -5,63 +5,85 @@ import logging
from django.utils.timezone import now
from django.db import DatabaseError

from openassessment.assessment.models import (
    AITrainingWorkflow, AIGradingWorkflow, AIClassifierSet,
    ClassifierUploadError, ClassifierSerializeError,
    IncompleteClassifierSet, NoTrainingExamples
)
from openassessment.assessment.errors import (
    AITrainingRequestError, AITrainingInternalError,
    AIGradingRequestError, AIGradingInternalError
)


logger = logging.getLogger(__name__)

def get_grading_task_params(grading_workflow_uuid):
    """
    Retrieve the classifier set and algorithm ID
    associated with a particular grading workflow.

    Args:
        grading_workflow_uuid (str): The UUID of the grading workflow.

    Returns:
        dict with keys:
            * essay_text (unicode): The text of the essay submission.
            * classifier_set (dict): Maps criterion names to serialized classifiers.
            * algorithm_id (unicode): ID of the algorithm used to perform training.

    Raises:
        AIGradingRequestError
        AIGradingInternalError
    """
    try:
        workflow = AIGradingWorkflow.objects.get(uuid=grading_workflow_uuid)
    except AIGradingWorkflow.DoesNotExist:
        msg = (
            u"Could not retrieve the AI grading workflow with uuid {}"
        ).format(grading_workflow_uuid)
        raise AIGradingRequestError(msg)
    except DatabaseError as ex:
        msg = (
            u"An unexpected error occurred while retrieving the "
            u"AI grading workflow with uuid {uuid}: {ex}"
        ).format(uuid=grading_workflow_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)

    classifier_set = workflow.classifier_set

    # Tasks shouldn't be scheduled until a classifier set is
    # available, so this is a serious internal error.
    if classifier_set is None:
        msg = (
            u"AI grading workflow with UUID {} has no classifier set"
        ).format(grading_workflow_uuid)
        logger.error(msg)
        raise AIGradingInternalError(msg)

    try:
        return {
            'essay_text': workflow.essay_text,
            'classifier_set': classifier_set.classifiers_dict,
            'algorithm_id': workflow.algorithm_id,
        }
    except (ValueError, IOError, DatabaseError) as ex:
        msg = (
            u"An unexpected error occurred while retrieving "
            u"classifiers for the grading workflow with UUID {uuid}: {ex}"
        ).format(uuid=grading_workflow_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)
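
The grading task below consumes this return value; an illustrative shape (criterion names, algorithm ID, and classifier payloads are placeholders):

    params = ai_worker_api.get_grading_task_params(workflow_uuid)
    # e.g.:
    # {
    #     'essay_text': u"The text of the student's essay ...",
    #     'classifier_set': {u"vocabulary": {...}, u"grammar": {...}},
    #     'algorithm_id': u"ease",
    # }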

def create_assessment(grading_workflow_uuid, criterion_scores):
    """
    Create an AI assessment (complete the AI grading task).

    Args:
        grading_workflow_uuid (str): The UUID of the grading workflow.
        criterion_scores (dict): Dictionary mapping criteria names to integer scores.

    Returns:
        None
@@ -71,57 +93,59 @@ def create_assessment(grading_workflow_uuid, criterion_scores):
        AIGradingInternalError
    """
    try:
        workflow = AIGradingWorkflow.objects.get(uuid=grading_workflow_uuid)
    except AIGradingWorkflow.DoesNotExist:
        msg = (
            u"Could not retrieve the AI grading workflow with uuid {}"
        ).format(grading_workflow_uuid)
        raise AIGradingRequestError(msg)
    except DatabaseError as ex:
        msg = (
            u"An unexpected error occurred while retrieving the "
            u"AI grading workflow with uuid {uuid}: {ex}"
        ).format(uuid=grading_workflow_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)

    # Optimization: if the workflow has already been marked complete
    # (perhaps the task was picked up by multiple workers),
    # then we don't need to do anything.
    # Otherwise, create the assessment and mark the workflow complete.
    try:
        if not workflow.is_complete:
            workflow.complete(criterion_scores)
    except DatabaseError as ex:
        msg = (
            u"An unexpected error occurred while creating the assessment "
            u"for AI grading workflow with uuid {uuid}: {ex}"
        ).format(uuid=grading_workflow_uuid, ex=ex)
        logger.exception(msg)
        raise AIGradingInternalError(msg)
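
An illustrative call, assuming a rubric with criteria named u"vocabulary" and u"grammar" (real scores come from the worker's scoring step in the grading task below):

    ai_worker_api.create_assessment(workflow_uuid, {
        u"vocabulary": 1,
        u"grammar": 2,
    })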

def get_training_task_params(training_workflow_uuid):
    """
    Retrieve the training examples and algorithm ID
    associated with a training task.

    Args:
        training_workflow_uuid (str): The UUID of the training workflow.

    Returns:
        dict with keys:
            * training_examples (list of dict): The examples used to train the classifiers.
            * algorithm_id (unicode): The ID of the algorithm to use for training.

    Raises:
        AITrainingRequestError
        AITrainingInternalError

    Example usage:
        >>> params = get_training_task_params('abcd1234')
        >>> params['algorithm_id']
        u'ease'
        >>> params['training_examples']
        [
            {
                "text": u"Example answer number one",
@@ -161,7 +185,10 @@ def get_training_examples(training_workflow_uuid):
                'scores': scores
            })

        return {
            'training_examples': returned_examples,
            'algorithm_id': workflow.algorithm_id
        }
    except AITrainingWorkflow.DoesNotExist:
        msg = (
            u"Could not retrieve AI training workflow with UUID {}"
...
@@ -166,6 +166,48 @@ class Rubric(models.Model):
        return option_id_set

    def options_ids_for_points(self, criterion_points):
        """
        Given a mapping of selected point values, return the option IDs.
        If there are multiple options with the same point value,
        this returns the first one (lowest order number).

        Args:
            criterion_points (dict): Mapping of criteria names to point values.

        Returns:
            set of option IDs

        Raises:
            InvalidOptionSelection
        """
        # This is a really inefficient initial implementation
        # TODO -- refactor to add caching
        rubric_options = CriterionOption.objects.filter(
            criterion__rubric=self
        ).select_related()

        rubric_points_dict = defaultdict(dict)
        for option in rubric_options:
            if option.points not in rubric_points_dict[option.criterion.name]:
                rubric_points_dict[option.criterion.name][option.points] = option.id

        option_id_set = set()
        for criterion_name, option_points in criterion_points.iteritems():
            if (criterion_name in rubric_points_dict and
                    option_points in rubric_points_dict[criterion_name]):
                option_id = rubric_points_dict[criterion_name][option_points]
                option_id_set.add(option_id)
            else:
                msg = _("{criterion} option with point value {points} not found in rubric").format(
                    criterion=criterion_name, points=option_points
                )
                raise InvalidOptionSelection(msg)

        return option_id_set
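
A usage sketch showing how the integer scores produced by an AI worker can be mapped back to rubric option IDs (criterion names are illustrative):

    option_ids = rubric.options_ids_for_points({
        u"vocabulary": 1,
        u"grammar": 2,
    })
    # Raises InvalidOptionSelection if any criterion has no option
    # worth the requested number of points.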

class Criterion(models.Model):
    """A single aspect of a submission that needs assessment.
...
@@ -4,3 +4,4 @@ so import the tasks we want the workers to implement.
"""
# pylint:disable=W0611
from .worker.training import train_classifiers
from .worker.grading import grade_essay
@@ -7,11 +7,13 @@ import mock
from django.db import DatabaseError
from django.test.utils import override_settings
from openassessment.test_utils import CacheResetTest
from submissions import api as sub_api
from openassessment.assessment.api import ai as ai_api
from openassessment.assessment.models import AITrainingWorkflow, AIClassifierSet
from openassessment.assessment.worker.algorithm import AIAlgorithm
from openassessment.assessment.serializers import rubric_from_dict
from openassessment.assessment.errors import AITrainingRequestError, AITrainingInternalError
from openassessment.assessment.test.constants import RUBRIC, EXAMPLES, STUDENT_ITEM, ANSWER


class StubAIAlgorithm(AIAlgorithm):
@@ -34,14 +36,25 @@ class StubAIAlgorithm(AIAlgorithm):
        # so we can test that the correct inputs were used
        classifier = copy.copy(self.FAKE_CLASSIFIER)
        classifier['examples'] = examples
        classifier['score_override'] = 0
        return classifier

    def score(self, text, classifier):
        """
        Stub implementation that returns whatever scores were
        provided in the serialized classifier data.

        Expect `classifier` to be a dict with a single key,
        "score_override" containing the score to return.
        """
        return classifier['score_override']


ALGORITHM_ID = "test-stub"

AI_ALGORITHMS = {
    ALGORITHM_ID: '{module}.StubAIAlgorithm'.format(module=__name__)
}

class AITrainingTest(CacheResetTest):
@@ -49,11 +62,6 @@ class AITrainingTest(CacheResetTest):
    Tests for AI training tasks.
    """
    EXPECTED_INPUT_SCORES = {
        u'vøȼȺƀᵾłȺɍɏ': [1, 0],
        u'ﻭɼค๓๓คɼ': [0, 2]
@@ -65,7 +73,7 @@ class AITrainingTest(CacheResetTest):
        # Schedule a training task
        # Because Celery is configured in "always eager" mode,
        # expect the task to be executed synchronously.
        workflow_uuid = ai_api.train_classifiers(RUBRIC, EXAMPLES, ALGORITHM_ID)

        # Retrieve the classifier set from the database
        workflow = AITrainingWorkflow.objects.get(uuid=workflow_uuid)
@@ -106,12 +114,12 @@ class AITrainingTest(CacheResetTest):
        # Expect a request error
        with self.assertRaises(AITrainingRequestError):
            ai_api.train_classifiers(RUBRIC, mutated_examples, ALGORITHM_ID)

    def test_train_classifiers_no_examples(self):
        # Empty list of training examples
        with self.assertRaises(AITrainingRequestError):
            ai_api.train_classifiers(RUBRIC, [], ALGORITHM_ID)

    @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
    @mock.patch.object(AITrainingWorkflow.objects, 'create')
@@ -119,7 +127,7 @@ class AITrainingTest(CacheResetTest):
        # Simulate a database error when creating the training workflow
        mock_create.side_effect = DatabaseError("KABOOM!")
        with self.assertRaises(AITrainingInternalError):
            ai_api.train_classifiers(RUBRIC, EXAMPLES, ALGORITHM_ID)

    @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
    @mock.patch('openassessment.assessment.api.ai.training_tasks')
@@ -127,4 +135,47 @@ class AITrainingTest(CacheResetTest):
        # Simulate an exception raised when scheduling a training task
        mock_training_tasks.train_classifiers.apply_async.side_effect = Exception("KABOOM!")
        with self.assertRaises(AITrainingInternalError):
            ai_api.train_classifiers(RUBRIC, EXAMPLES, ALGORITHM_ID)

class AIGradingTest(CacheResetTest):
    """
    Tests for AI grading tasks.
    """
    CLASSIFIER_SCORE_OVERRIDES = {
        u"vøȼȺƀᵾłȺɍɏ": {'score_override': 1},
        u"ﻭɼค๓๓คɼ": {'score_override': 2}
    }

    def setUp(self):
        """
        Create a submission and a fake classifier set.
        """
        # Create a submission
        submission = sub_api.create_submission(STUDENT_ITEM, ANSWER)
        self.submission_uuid = submission['uuid']

        # Create the classifier set for our fake AI algorithm.
        # To isolate these tests from the tests for the training
        # task, we use the database models directly.
        # We also use a stub AI algorithm that simply returns
        # whatever scores we specify in the classifier data.
        rubric = rubric_from_dict(RUBRIC)
        AIClassifierSet.create_classifier_set(
            self.CLASSIFIER_SCORE_OVERRIDES, rubric, ALGORITHM_ID
        )

    @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
    def test_grade_essay(self):
        # Schedule a grading task.
        # Because Celery is configured in "always eager" mode, this will
        # be executed synchronously.
        ai_api.submit(self.submission_uuid, RUBRIC, ALGORITHM_ID)

        # Verify that we got the scores we provided to the stub AI algorithm
        assessment = ai_api.get_latest_assessment(self.submission_uuid)
        for part in assessment['parts']:
            criterion_name = part['option']['criterion']['name']
            expected_score = self.CLASSIFIER_SCORE_OVERRIDES[criterion_name]['score_override']
            self.assertEqual(part['option']['points'], expected_score)
@@ -45,22 +45,8 @@ class AIWorkerTrainingTest(CacheResetTest):
        workflow = AITrainingWorkflow.start_workflow(examples, self.ALGORITHM_ID)
        self.workflow_uuid = workflow.uuid

    def test_get_training_task_params(self):
        params = ai_worker_api.get_training_task_params(self.workflow_uuid)
        expected_examples = [
            {
                'text': EXAMPLES[0]['answer'],
@@ -77,17 +63,18 @@ class AIWorkerTrainingTest(CacheResetTest):
                }
            },
        ]
        self.assertItemsEqual(params['training_examples'], expected_examples)
        self.assertEqual(params['algorithm_id'], self.ALGORITHM_ID)

    def test_get_training_task_params_no_workflow(self):
        with self.assertRaises(AITrainingRequestError):
            ai_worker_api.get_training_task_params("invalid_uuid")

    @mock.patch.object(AITrainingWorkflow.objects, 'get')
    def test_get_training_task_params_database_error(self, mock_get):
        mock_get.side_effect = DatabaseError("KABOOM!")
        with self.assertRaises(AITrainingInternalError):
            ai_worker_api.get_training_task_params(self.workflow_uuid)

    def test_create_classifiers(self):
        ai_worker_api.create_classifiers(self.workflow_uuid, self.CLASSIFIERS)
...
@@ -3,7 +3,6 @@
Tests for AI worker tasks.
"""
from contextlib import contextmanager
import mock
from django.test.utils import override_settings
from openassessment.test_utils import CacheResetTest
@@ -64,13 +63,6 @@ class AITrainingTaskTest(CacheResetTest):
        workflow = AITrainingWorkflow.start_workflow(examples, self.ALGORITHM_ID)
        self.workflow_uuid = workflow.uuid

    def test_unknown_algorithm(self):
        # Since we haven't overridden settings to configure the algorithms,
        # the worker will not recognize the workflow's algorithm ID.
@@ -92,8 +84,8 @@ class AITrainingTaskTest(CacheResetTest):
        train_classifiers(self.workflow_uuid)

    @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
    @mock.patch('openassessment.assessment.worker.training.ai_worker_api.get_training_task_params')
    def test_get_training_task_params_api_error(self, mock_call):
        mock_call.side_effect = AITrainingRequestError("Test error!")
        with self._assert_retry(train_classifiers, AITrainingRequestError):
            train_classifiers(self.workflow_uuid)
@@ -160,12 +152,12 @@ class AITrainingTaskTest(CacheResetTest):
            AssertionError
        """
        params = ai_worker_api.get_training_task_params(self.workflow_uuid)
        mutate_func(params['training_examples'])

        call_signature = 'openassessment.assessment.worker.training.ai_worker_api.get_training_task_params'
        with mock.patch(call_signature) as mock_call:
            mock_call.return_value = params
            with self._assert_retry(train_classifiers, InvalidExample):
                train_classifiers(self.workflow_uuid)
...
"""
Asynchronous tasks for grading essays using text classifiers.
"""
from celery import task
from celery.utils.log import get_task_logger
from openassessment.assessment.api import ai_worker as ai_worker_api
from openassessment.assessment.errors import AIError
from .algorithm import AIAlgorithm, AIAlgorithmError
MAX_RETRIES = 2
logger = get_task_logger(__name__)
@task(max_retries=MAX_RETRIES) # pylint: disable=E1102
def grade_essay(workflow_uuid):
"""
Asynchronous task to grade an essay using a text classifier
(trained using a supervised ML algorithm).
If the task could not be completed successfully,
it will be retried a few times; if it continues to fail,
it is left incomplete. Incomplate tasks can be rescheduled
manually through the AI API.
Args:
workflow_uuid (str): The UUID of the workflow associated
with this grading task.
Returns:
None
Raises:
AIError: An error occurred while making an AI worker API call.
AIAlgorithmError: An error occurred while retrieving or using an AI algorithm.
"""
# Retrieve the task parameters
try:
params = ai_worker_api.get_grading_task_params(workflow_uuid)
essay_text = params['essay_text']
classifier_set = params['classifier_set']
algorithm_id = params['algorithm_id']
except (AIError, KeyError):
msg = (
u"An error occurred while retrieving the AI grading task "
u"parameters for the workflow with UUID {}"
).format(workflow_uuid)
logger.exception(msg)
raise grade_essay.retry()
# Retrieve the AI algorithm
try:
algorithm = AIAlgorithm.algorithm_for_id(algorithm_id)
except AIAlgorithmError:
msg = (
u"An error occurred while retrieving "
u"the algorithm ID (grading workflow UUID {})"
).format(workflow_uuid)
logger.exception(msg)
raise grade_essay.retry()
# Use the algorithm to evaluate the essay for each criterion
try:
scores_by_criterion = {
criterion_name: algorithm.score(essay_text, classifier)
for criterion_name, classifier in classifier_set.iteritems()
}
except AIAlgorithmError:
msg = (
u"An error occurred while scoring essays using "
u"an AI algorithm (worker workflow UUID {})"
).format(workflow_uuid)
logger.exception(msg)
raise grade_essay.retry()
# Create the assessment and mark the workflow complete
try:
ai_worker_api.create_assessment(workflow_uuid, scores_by_criterion)
except AIError:
msg = (
u"An error occurred while creating assessments "
u"for the AI grading workflow with UUID {uuid}. "
u"The assessment scores were: {scores}"
).format(uuid=workflow_uuid, scores=scores_by_criterion)
logger.exception(msg)
raise grade_essay.retry()
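
Workers resolve `algorithm_id` to an implementation through Django settings, as the tests above do with `override_settings`. A hedged configuration sketch, mirroring the test constants (the dotted path is hypothetical):

    # settings.py: map algorithm IDs to importable AIAlgorithm subclasses.
    ORA2_AI_ALGORITHMS = {
        u"ease": "myapp.ai.EaseAIAlgorithm",  # hypothetical module path
    }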
@@ -53,10 +53,22 @@ def train_classifiers(workflow_uuid):
        InvalidExample: The training examples provided by the AI API were not valid.
    """
    # Retrieve task parameters
    try:
        params = ai_worker_api.get_training_task_params(workflow_uuid)
        examples = params['training_examples']
        algorithm_id = params['algorithm_id']
    except (AIError, KeyError):
        msg = (
            u"An error occurred while retrieving AI training "
            u"task parameters for the workflow with UUID {}"
        ).format(workflow_uuid)
        logger.exception(msg)
        raise train_classifiers.retry()

    # Retrieve the ML algorithm to use for training
    # (based on task params and worker configuration)
    try:
        algorithm = AIAlgorithm.algorithm_for_id(algorithm_id)
    except AIAlgorithmError:
        msg = (
@@ -73,19 +85,6 @@ def train_classifiers(workflow_uuid):
        logger.exception(msg)
        raise train_classifiers.retry()

    # Train a classifier for each criterion
    # The AIAlgorithm subclass is responsible for ensuring that
    # the trained classifiers are JSON-serializable.
...
@@ -240,17 +240,15 @@ Data Model
1. **GradingWorkflow** (see the Django model sketch after this list)

    a. Submission UUID (varchar)
    b. ClassifierSet (Foreign Key, Nullable)
    c. Assessment (Foreign Key, Nullable)
    d. Rubric (Foreign Key): Used to search for classifier sets if none are available when the workflow is started.
    e. Algorithm ID (varchar): Used to search for classifier sets if none are available when the workflow is started.
    f. Scheduled at (timestamp): The time the task was placed on the queue.
    g. Completed at (timestamp): The time the task was completed. If set, the task is considered complete.
    h. Course ID (varchar): The ID of the course associated with the submission. Useful for rescheduling failed grading tasks in a particular course.
    i. Item ID (varchar): The ID of the item (problem) associated with the submission. Useful for rescheduling failed grading tasks in a particular item in a course.

2. **TrainingWorkflow**
@@ -269,13 +267,13 @@ Data Model
    a. Rubric (Foreign Key)
    b. Created at (timestamp)
    c. Algorithm ID (varchar)

5. **Classifier**

    a. ClassifierSet (Foreign Key)
    b. URL for trained classifier (varchar)
    c. Criterion (Foreign Key)

6. **Assessment** (same as current implementation)
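
A hedged sketch of the GradingWorkflow fields above as a Django model; field names and options are illustrative assumptions, and the authoritative definition lives in the models module:

    # Illustrative only: one field per item in the list above.
    class AIGradingWorkflow(models.Model):
        submission_uuid = models.CharField(max_length=128, db_index=True)
        classifier_set = models.ForeignKey(AIClassifierSet, null=True, default=None)
        assessment = models.ForeignKey(Assessment, null=True, default=None)
        # Rubric and algorithm ID are kept so classifier sets can be
        # located later if none were available when the workflow started.
        rubric = models.ForeignKey(Rubric)
        algorithm_id = models.CharField(max_length=128, db_index=True)
        scheduled_at = models.DateTimeField(db_index=True)
        completed_at = models.DateTimeField(null=True, db_index=True)
        course_id = models.CharField(max_length=40, db_index=True)
        item_id = models.CharField(max_length=128, db_index=True)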
...
@@ -4,5 +4,5 @@
set -e
cd `dirname $BASH_SOURCE` && cd ..
./scripts/test-python.sh $1
./scripts/test-js.sh