Commit 01724eb9 by Will Daly

Merge pull request #353 from edx/will/refactor-training-api-call

Will/refactor training api call
parents 2915697b e424c439
......@@ -8,7 +8,9 @@ from openassessment.assessment.serializers import (
from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError
)
from openassessment.assessment.models import AITrainingWorkflow, InvalidOptionSelection
from openassessment.assessment.models import (
AITrainingWorkflow, InvalidOptionSelection, NoTrainingExamples
)
from openassessment.assessment.worker import training as training_tasks
......@@ -89,6 +91,8 @@ def train_classifiers(rubric_dict, examples, algorithm_id):
# Create the workflow model
try:
workflow = AITrainingWorkflow.start_workflow(examples, algorithm_id)
except NoTrainingExamples as ex:
raise AITrainingRequestError(ex)
except:
msg = (
u"An unexpected error occurred while creating "
......
......@@ -7,7 +7,7 @@ from django.db import DatabaseError
from openassessment.assessment.models import (
AITrainingWorkflow, AIClassifierSet,
ClassifierUploadError, ClassifierSerializeError,
IncompleteClassifierSet
IncompleteClassifierSet, NoTrainingExamples
)
from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError
......@@ -185,8 +185,7 @@ def create_classifiers(training_workflow_uuid, classifier_set):
Args:
training_workflow_uuid (str): The UUID of the training workflow.
classifier_set (dict): Mapping of criterion names to serialized classifiers.
(binary classifiers should be base-64 encoded).
classifier_set (dict): Mapping of criteria names to serialized classifiers.
Returns:
None
......@@ -201,47 +200,32 @@ def create_classifiers(training_workflow_uuid, classifier_set):
# If the task is executed multiple times, the classifier set may already
# have been created. If so, log a warning then return immediately.
if workflow.classifier_set is not None:
if workflow.is_complete:
msg = u"AI training workflow with UUID {} already has trained classifiers."
logger.warning(msg)
return
# Retrieve the rubric model
rubric = workflow.rubric
if rubric is None:
msg = (
u"The AI training workflow with UUID {} does not have "
u"a rubric associated with it, which means it has no "
u"training examples."
).format(training_workflow_uuid)
logger.exception(msg)
raise AITrainingInternalError(msg)
try:
workflow.classifier_set = AIClassifierSet.create_classifier_set(
classifier_set, rubric, workflow.algorithm_id
)
except IncompleteClassifierSet as ex:
msg = (
u"An error occurred while creating the classifier set "
u"for the training workflow with UUID {uuid}: {ex}"
).format(uuid=training_workflow_uuid, ex=ex)
raise AITrainingRequestError(msg)
except (ClassifierSerializeError, ClassifierUploadError, DatabaseError) as ex:
msg = (
u"An unexpected error occurred while creating the classifier "
u"set for training workflow UUID {uuid}: {ex}"
).format(uuid=training_workflow_uuid, ex=ex)
logger.exception(msg)
raise AITrainingInternalError(msg)
workflow.completed_at = now()
workflow.save()
else:
workflow.complete(classifier_set)
except AITrainingWorkflow.DoesNotExist:
msg = (
u"Could not retrieve AI training workflow with UUID {}"
).format(training_workflow_uuid)
raise AITrainingRequestError(msg)
except NoTrainingExamples as ex:
logger.exception(ex)
raise AITrainingInternalError(ex)
except IncompleteClassifierSet as ex:
msg = (
u"An error occurred while creating the classifier set "
u"for the training workflow with UUID {uuid}: {ex}"
).format(uuid=training_workflow_uuid, ex=ex)
raise AITrainingRequestError(msg)
except (ClassifierSerializeError, ClassifierUploadError, DatabaseError) as ex:
msg = (
u"An unexpected error occurred while creating the classifier "
u"set for training workflow UUID {uuid}: {ex}"
).format(uuid=training_workflow_uuid, ex=ex)
logger.exception(msg)
raise AITrainingInternalError(msg)
except DatabaseError:
msg = (
u"An unexpected error occurred while creating the classifier set "
......
......@@ -40,6 +40,19 @@ class ClassifierSerializeError(Exception):
pass
class NoTrainingExamples(Exception):
"""
No training examples were provided to the workflow.
"""
def __init__(self, workflow_uuid=None):
msg = u"No training examples were provided"
if workflow_uuid is not None:
msg = u"{msg} to the training workflow with UUID {uuid}".format(
msg=msg, uuid=workflow_uuid
)
super(NoTrainingExamples, self).__init__(msg)
class AIClassifierSet(models.Model):
"""
A set of trained classifiers (immutable).
......@@ -242,7 +255,13 @@ class AITrainingWorkflow(models.Model):
Returns:
AITrainingWorkflow
Raises:
NoTrainingExamples
"""
if len(examples) == 0:
raise NoTrainingExamples()
workflow = AITrainingWorkflow.objects.create(algorithm_id=algorithm_id)
workflow.training_examples.add(*examples)
workflow.save()
......@@ -256,6 +275,9 @@ class AITrainingWorkflow(models.Model):
Returns:
Rubric or None (if no training examples are available)
Raises:
NoTrainingExamples
"""
# We assume that all the training examples we have been provided are using
# the same rubric (this is enforced by the API call that deserializes
......@@ -264,4 +286,38 @@ class AITrainingWorkflow(models.Model):
if first_example:
return first_example[0].rubric
else:
return None
raise NoTrainingExamples(workflow_uuid=self.uuid)
@property
def is_complete(self):
"""
Check whether the workflow is complete (classifiers have been trained).
Returns:
bool
"""
return self.completed_at is not None
def complete(self, classifier_set):
"""
Add a classifier set to the workflow and mark it complete.
Args:
classifier_set (dict): Mapping of criteria names to serialized classifiers.
Returns:
None
Raises:
NoTrainingExamples
IncompleteClassifierSet
ClassifierSerializeError
ClassifierUploadError
DatabaseError
"""
self.classifier_set = AIClassifierSet.create_classifier_set(
classifier_set, self.rubric, self.algorithm_id
)
self.completed_at = now()
self.save()
......@@ -108,6 +108,11 @@ class AITrainingTest(CacheResetTest):
with self.assertRaises(AITrainingRequestError):
ai_api.train_classifiers(RUBRIC, mutated_examples, self.ALGORITHM_ID)
def test_train_classifiers_no_examples(self):
# Empty list of training examples
with self.assertRaises(AITrainingRequestError):
ai_api.train_classifiers(RUBRIC, [], self.ALGORITHM_ID)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
@mock.patch.object(AITrainingWorkflow.objects, 'create')
def test_start_workflow_database_error(self, mock_create):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment