Commit 01724eb9 by Will Daly

Merge pull request #353 from edx/will/refactor-training-api-call

Will/refactor training api call
parents 2915697b e424c439
...@@ -8,7 +8,9 @@ from openassessment.assessment.serializers import ( ...@@ -8,7 +8,9 @@ from openassessment.assessment.serializers import (
from openassessment.assessment.errors import ( from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError AITrainingRequestError, AITrainingInternalError
) )
from openassessment.assessment.models import AITrainingWorkflow, InvalidOptionSelection from openassessment.assessment.models import (
AITrainingWorkflow, InvalidOptionSelection, NoTrainingExamples
)
from openassessment.assessment.worker import training as training_tasks from openassessment.assessment.worker import training as training_tasks
...@@ -89,6 +91,8 @@ def train_classifiers(rubric_dict, examples, algorithm_id): ...@@ -89,6 +91,8 @@ def train_classifiers(rubric_dict, examples, algorithm_id):
# Create the workflow model # Create the workflow model
try: try:
workflow = AITrainingWorkflow.start_workflow(examples, algorithm_id) workflow = AITrainingWorkflow.start_workflow(examples, algorithm_id)
except NoTrainingExamples as ex:
raise AITrainingRequestError(ex)
except: except:
msg = ( msg = (
u"An unexpected error occurred while creating " u"An unexpected error occurred while creating "
......
...@@ -7,7 +7,7 @@ from django.db import DatabaseError ...@@ -7,7 +7,7 @@ from django.db import DatabaseError
from openassessment.assessment.models import ( from openassessment.assessment.models import (
AITrainingWorkflow, AIClassifierSet, AITrainingWorkflow, AIClassifierSet,
ClassifierUploadError, ClassifierSerializeError, ClassifierUploadError, ClassifierSerializeError,
IncompleteClassifierSet IncompleteClassifierSet, NoTrainingExamples
) )
from openassessment.assessment.errors import ( from openassessment.assessment.errors import (
AITrainingRequestError, AITrainingInternalError AITrainingRequestError, AITrainingInternalError
...@@ -185,8 +185,7 @@ def create_classifiers(training_workflow_uuid, classifier_set): ...@@ -185,8 +185,7 @@ def create_classifiers(training_workflow_uuid, classifier_set):
Args: Args:
training_workflow_uuid (str): The UUID of the training workflow. training_workflow_uuid (str): The UUID of the training workflow.
classifier_set (dict): Mapping of criterion names to serialized classifiers. classifier_set (dict): Mapping of criteria names to serialized classifiers.
(binary classifiers should be base-64 encoded).
Returns: Returns:
None None
...@@ -201,26 +200,19 @@ def create_classifiers(training_workflow_uuid, classifier_set): ...@@ -201,26 +200,19 @@ def create_classifiers(training_workflow_uuid, classifier_set):
# If the task is executed multiple times, the classifier set may already # If the task is executed multiple times, the classifier set may already
# have been created. If so, log a warning then return immediately. # have been created. If so, log a warning then return immediately.
if workflow.classifier_set is not None: if workflow.is_complete:
msg = u"AI training workflow with UUID {} already has trained classifiers." msg = u"AI training workflow with UUID {} already has trained classifiers."
logger.warning(msg) logger.warning(msg)
return else:
workflow.complete(classifier_set)
# Retrieve the rubric model except AITrainingWorkflow.DoesNotExist:
rubric = workflow.rubric
if rubric is None:
msg = ( msg = (
u"The AI training workflow with UUID {} does not have " u"Could not retrieve AI training workflow with UUID {}"
u"a rubric associated with it, which means it has no "
u"training examples."
).format(training_workflow_uuid) ).format(training_workflow_uuid)
logger.exception(msg) raise AITrainingRequestError(msg)
raise AITrainingInternalError(msg) except NoTrainingExamples as ex:
logger.exception(ex)
try: raise AITrainingInternalError(ex)
workflow.classifier_set = AIClassifierSet.create_classifier_set(
classifier_set, rubric, workflow.algorithm_id
)
except IncompleteClassifierSet as ex: except IncompleteClassifierSet as ex:
msg = ( msg = (
u"An error occurred while creating the classifier set " u"An error occurred while creating the classifier set "
...@@ -234,14 +226,6 @@ def create_classifiers(training_workflow_uuid, classifier_set): ...@@ -234,14 +226,6 @@ def create_classifiers(training_workflow_uuid, classifier_set):
).format(uuid=training_workflow_uuid, ex=ex) ).format(uuid=training_workflow_uuid, ex=ex)
logger.exception(msg) logger.exception(msg)
raise AITrainingInternalError(msg) raise AITrainingInternalError(msg)
workflow.completed_at = now()
workflow.save()
except AITrainingWorkflow.DoesNotExist:
msg = (
u"Could not retrieve AI training workflow with UUID {}"
).format(training_workflow_uuid)
raise AITrainingRequestError(msg)
except DatabaseError: except DatabaseError:
msg = ( msg = (
u"An unexpected error occurred while creating the classifier set " u"An unexpected error occurred while creating the classifier set "
......
...@@ -40,6 +40,19 @@ class ClassifierSerializeError(Exception): ...@@ -40,6 +40,19 @@ class ClassifierSerializeError(Exception):
pass pass
class NoTrainingExamples(Exception):
"""
No training examples were provided to the workflow.
"""
def __init__(self, workflow_uuid=None):
msg = u"No training examples were provided"
if workflow_uuid is not None:
msg = u"{msg} to the training workflow with UUID {uuid}".format(
msg=msg, uuid=workflow_uuid
)
super(NoTrainingExamples, self).__init__(msg)
class AIClassifierSet(models.Model): class AIClassifierSet(models.Model):
""" """
A set of trained classifiers (immutable). A set of trained classifiers (immutable).
...@@ -242,7 +255,13 @@ class AITrainingWorkflow(models.Model): ...@@ -242,7 +255,13 @@ class AITrainingWorkflow(models.Model):
Returns: Returns:
AITrainingWorkflow AITrainingWorkflow
Raises:
NoTrainingExamples
""" """
if len(examples) == 0:
raise NoTrainingExamples()
workflow = AITrainingWorkflow.objects.create(algorithm_id=algorithm_id) workflow = AITrainingWorkflow.objects.create(algorithm_id=algorithm_id)
workflow.training_examples.add(*examples) workflow.training_examples.add(*examples)
workflow.save() workflow.save()
...@@ -256,6 +275,9 @@ class AITrainingWorkflow(models.Model): ...@@ -256,6 +275,9 @@ class AITrainingWorkflow(models.Model):
Returns: Returns:
Rubric or None (if no training examples are available) Rubric or None (if no training examples are available)
Raises:
NoTrainingExamples
""" """
# We assume that all the training examples we have been provided are using # We assume that all the training examples we have been provided are using
# the same rubric (this is enforced by the API call that deserializes # the same rubric (this is enforced by the API call that deserializes
...@@ -264,4 +286,38 @@ class AITrainingWorkflow(models.Model): ...@@ -264,4 +286,38 @@ class AITrainingWorkflow(models.Model):
if first_example: if first_example:
return first_example[0].rubric return first_example[0].rubric
else: else:
return None raise NoTrainingExamples(workflow_uuid=self.uuid)
@property
def is_complete(self):
"""
Check whether the workflow is complete (classifiers have been trained).
Returns:
bool
"""
return self.completed_at is not None
def complete(self, classifier_set):
"""
Add a classifier set to the workflow and mark it complete.
Args:
classifier_set (dict): Mapping of criteria names to serialized classifiers.
Returns:
None
Raises:
NoTrainingExamples
IncompleteClassifierSet
ClassifierSerializeError
ClassifierUploadError
DatabaseError
"""
self.classifier_set = AIClassifierSet.create_classifier_set(
classifier_set, self.rubric, self.algorithm_id
)
self.completed_at = now()
self.save()
...@@ -108,6 +108,11 @@ class AITrainingTest(CacheResetTest): ...@@ -108,6 +108,11 @@ class AITrainingTest(CacheResetTest):
with self.assertRaises(AITrainingRequestError): with self.assertRaises(AITrainingRequestError):
ai_api.train_classifiers(RUBRIC, mutated_examples, self.ALGORITHM_ID) ai_api.train_classifiers(RUBRIC, mutated_examples, self.ALGORITHM_ID)
def test_train_classifiers_no_examples(self):
# Empty list of training examples
with self.assertRaises(AITrainingRequestError):
ai_api.train_classifiers(RUBRIC, [], self.ALGORITHM_ID)
@override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS) @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
@mock.patch.object(AITrainingWorkflow.objects, 'create') @mock.patch.object(AITrainingWorkflow.objects, 'create')
def test_start_workflow_database_error(self, mock_create): def test_start_workflow_database_error(self, mock_create):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment