Commit 41caac06 by Will Daly

Train classifiers in the simulate_ai_grading_error management command

parent fe534349
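With this change the command takes a fourth positional argument. Below is a minimal usage sketch via Django's `call_command`; the course and problem IDs are hypothetical placeholders, and the "ease" algorithm presumably requires the EASE library to be available:

# Hedged usage sketch; argument order follows the command's `args` string.
from django.core.management import call_command

call_command(
    'simulate_ai_grading_error',
    'edX/Demo/2014',  # COURSE_ID (hypothetical value)
    'problem_1',      # PROBLEM_ID (hypothetical value)
    '20',             # NUM_SUBMISSIONS
    'fake',           # ALGORITHM_ID: "fake" or "ease"
)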
@@ -14,6 +14,7 @@ from django.core.management.base import BaseCommand, CommandError
 from submissions import api as sub_api
 from openassessment.assessment.models import AIGradingWorkflow, AIClassifierSet
 from openassessment.assessment.serializers import rubric_from_dict
+from openassessment.assessment.worker.algorithm import AIAlgorithm


 class Command(BaseCommand):
@@ -26,7 +27,7 @@ class Command(BaseCommand):
         u"by creating incomplete AI grading workflows in the database."
     )

-    args = '<COURSE_ID> <PROBLEM_ID> <NUM_SUBMISSIONS>'
+    args = '<COURSE_ID> <PROBLEM_ID> <NUM_SUBMISSIONS> <ALGORITHM_ID>'

     RUBRIC_OPTIONS = [
         {
@@ -61,15 +62,34 @@ class Command(BaseCommand):
         ]
     }

-    # Since we're not actually running an AI scoring algorithm,
-    # we can use dummy data for the classifier, as long as it's
-    # JSON-serializable.
-    CLASSIFIERS = {
-        u'vocabulary': {},
-        u'grammar': {}
+    EXAMPLES = {
+        "vocabulary": [
+            AIAlgorithm.ExampleEssay(
+                text=u"World Food Day is celebrated every year around the world on 16 October in honor "
+                     u"of the date of the founding of the Food and Agriculture "
+                     u"Organization of the United Nations in 1945.",
+                score=0
+            ),
+            AIAlgorithm.ExampleEssay(
+                text=u"Since 1981, World Food Day has adopted a different theme each year "
+                     u"in order to highlight areas needed for action and provide a common focus.",
+                score=1
+            ),
+        ],
+        "grammar": [
+            AIAlgorithm.ExampleEssay(
+                text=u"Most of the themes revolve around agriculture because only investment in agriculture ",
+                score=0
+            ),
+            AIAlgorithm.ExampleEssay(
+                text=u"In spite of the importance of agriculture as the driving force "
+                     u"in the economies of many developing countries, this "
+                     u"vital sector is frequently starved of investment.",
+                score=1
+            )
+        ]
     }
-    ALGORITHM_ID = u'fake'

     STUDENT_ID = u'test_student'
     ANSWER = {'answer': 'test answer'}
@@ -81,26 +101,38 @@ class Command(BaseCommand):
             course_id (unicode): The ID of the course to create submissions/workflows in.
             item_id (unicode): The ID of the problem in the course.
             num_submissions (int): The number of submissions/workflows to create.
+            algorithm_id (unicode): The ID of the ML algorithm to use ("fake" or "ease")

         Raises:
             CommandError

         """
-        if len(args) < 3:
+        if len(args) < 4:
            raise CommandError(u"Usage: simulate_ai_grading_error {}".format(self.args))

         # Parse arguments
         course_id = args[0].decode('utf-8')
         item_id = args[1].decode('utf-8')
         num_submissions = int(args[2])
+        algorithm_id = args[3].decode('utf-8')

         # Create the rubric model
         rubric = rubric_from_dict(self.RUBRIC)

+        # Train classifiers
+        print u"Training classifiers using {algorithm_id}...".format(algorithm_id=algorithm_id)
+        algorithm = AIAlgorithm.algorithm_for_id(algorithm_id)
+        classifier_data = {
+            criterion_name: algorithm.train_classifier(example)
+            for criterion_name, example in self.EXAMPLES.iteritems()
+        }
+        print u"Successfully trained classifiers."

         # Create the classifier set
         classifier_set = AIClassifierSet.create_classifier_set(
-            self.CLASSIFIERS, rubric, self.ALGORITHM_ID
+            classifier_data, rubric, algorithm_id
         )
+        print u"Successfully created classifier set with id {}".format(classifier_set.pk)

         # Create submissions and grading workflows
         for num in range(num_submissions):
@@ -112,7 +144,7 @@ class Command(BaseCommand):
             }
             submission = sub_api.create_submission(student_item, self.ANSWER)
             workflow = AIGradingWorkflow.start_workflow(
-                submission['uuid'], self.RUBRIC, self.ALGORITHM_ID
+                submission['uuid'], self.RUBRIC, algorithm_id
             )
             workflow.classifier_set = classifier_set
             workflow.save()
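The command now leans on three pieces of the `AIAlgorithm` interface: the `ExampleEssay` container, the `algorithm_for_id` factory, and `train_classifier`. As a reading aid, here is a rough sketch of that contract inferred only from the call sites above; the shapes and comments are assumptions, not the real `openassessment.assessment.worker.algorithm` implementation:

# Sketch (assumed, not the actual module) of the interface the command uses.
from collections import namedtuple

class AIAlgorithm(object):
    # One training example: essay text plus the score a human assigned it.
    ExampleEssay = namedtuple('ExampleEssay', ['text', 'score'])

    @classmethod
    def algorithm_for_id(cls, algorithm_id):
        # Assumption: resolves an ID such as "fake" or "ease" to an
        # algorithm instance (the test below wires this up through the
        # ORA2_AI_ALGORITHMS setting).
        raise NotImplementedError

    def train_classifier(self, examples):
        # Assumption: takes the list of ExampleEssay tuples for one rubric
        # criterion and returns JSON-serializable classifier data.
        raise NotImplementedError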
...
@@ -3,9 +3,11 @@
 Tests for the simulate AI grading error management command.
 """
+from django.test.utils import override_settings
 from openassessment.test_utils import CacheResetTest
 from openassessment.management.commands import simulate_ai_grading_error
 from openassessment.assessment.models import AIGradingWorkflow
+from openassessment.assessment.worker.grading import grade_essay


 class SimulateAIGradingErrorTest(CacheResetTest):
@@ -17,14 +19,19 @@ class SimulateAIGradingErrorTest(CacheResetTest):
     ITEM_ID = u"𝖙𝖊𝖘𝖙 𝖎𝖙𝖊𝖒"
     NUM_SUBMISSIONS = 20

+    AI_ALGORITHMS = {
+        "fake": "openassessment.assessment.worker.algorithm.FakeAIAlgorithm"
+    }
+
+    @override_settings(ORA2_AI_ALGORITHMS=AI_ALGORITHMS)
     def test_simulate_ai_grading_error(self):
         # Run the command
         cmd = simulate_ai_grading_error.Command()
         cmd.handle(
             self.COURSE_ID.encode('utf-8'),
             self.ITEM_ID.encode('utf-8'),
-            self.NUM_SUBMISSIONS
+            self.NUM_SUBMISSIONS,
+            "fake"
         )

         # Check that the correct number of incomplete workflows
@@ -33,8 +40,24 @@ class SimulateAIGradingErrorTest(CacheResetTest):
         # wouldn't have been scheduled for grading
         # (that is, the submissions were made before classifier
         # training completed).
-        num_errors = AIGradingWorkflow.objects.filter(
+        incomplete_workflows = AIGradingWorkflow.objects.filter(
             classifier_set__isnull=False,
             completed_at__isnull=True
-        ).count()
+        )
+        num_errors = incomplete_workflows.count()
         self.assertEqual(self.NUM_SUBMISSIONS, num_errors)

+        # Verify that we can complete the workflows successfully
+        # (that is, make sure the classifier data is valid).
+        # We're calling a Celery task method here,
+        # but we're NOT using `apply_async`, so this will
+        # execute synchronously.
+        for workflow in incomplete_workflows:
+            grade_essay(workflow.uuid)
+
+        # Now there should be no incomplete workflows
+        remaining_incomplete = AIGradingWorkflow.objects.filter(
+            classifier_set__isnull=False,
+            completed_at__isnull=True
+        ).count()
+        self.assertEqual(remaining_incomplete, 0)
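One detail the test's comments rely on is worth making explicit: a function registered as a Celery task is still an ordinary Python callable, so `grade_essay(workflow.uuid)` executes synchronously in the test process; only `delay()` or `apply_async()` would send it through a broker to a worker. A generic, self-contained illustration (not ora2 code):

# Generic Celery illustration: a task called directly runs in-process.
from celery import Celery

app = Celery('demo')

@app.task
def add(x, y):
    return x + y

result = add(2, 3)  # plain call: executes synchronously, returns 5

# add.delay(2, 3)   # the queued variant: requires a broker and a worker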