Commit e237df58 by Will Daly

Merge pull request #446 from edx/will/ai-use-cpickle

Use cPickle instead of pickle; add additional tests
parents ec4094c4 4d8635ba
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
Tests for AI algorithm implementations. Tests for AI algorithm implementations.
""" """
import unittest import unittest
import json
import mock import mock
from openassessment.test_utils import CacheResetTest from openassessment.test_utils import CacheResetTest
from openassessment.assessment.worker.algorithm import ( from openassessment.assessment.worker.algorithm import (
...@@ -16,14 +17,20 @@ EXAMPLES = [ ...@@ -16,14 +17,20 @@ EXAMPLES = [
AIAlgorithm.ExampleEssay(u"How years ago in days of old, when 𝒎𝒂𝒈𝒊𝒄 filled th air.", 1), AIAlgorithm.ExampleEssay(u"How years ago in days of old, when 𝒎𝒂𝒈𝒊𝒄 filled th air.", 1),
AIAlgorithm.ExampleEssay(u"Ṫ'ẅäṡ in the darkest depths of Ṁöṛḋöṛ, I met a girl so fair.", 1), AIAlgorithm.ExampleEssay(u"Ṫ'ẅäṡ in the darkest depths of Ṁöṛḋöṛ, I met a girl so fair.", 1),
AIAlgorithm.ExampleEssay(u"But goレレuᄊ, and the evil one crept up and slipped away with her", 0), AIAlgorithm.ExampleEssay(u"But goレレuᄊ, and the evil one crept up and slipped away with her", 0),
AIAlgorithm.ExampleEssay(u"", 4) AIAlgorithm.ExampleEssay(u"", 4),
AIAlgorithm.ExampleEssay(u".!?", 4),
AIAlgorithm.ExampleEssay(u"no punctuation", 4),
AIAlgorithm.ExampleEssay(u"one", 4),
] ]
INPUT_ESSAYS = [ INPUT_ESSAYS = [
u"Good times, 𝑩𝒂𝒅 𝑻𝒊𝒎𝒆𝒔, you know I had my share", u"Good times, 𝑩𝒂𝒅 𝑻𝒊𝒎𝒆𝒔, you know I had my share",
u"When my woman left home for a 𝒃𝒓𝒐𝒘𝒏 𝒆𝒚𝒆𝒅 𝒎𝒂𝒏", u"When my woman left home for a 𝒃𝒓𝒐𝒘𝒏 𝒆𝒚𝒆𝒅 𝒎𝒂𝒏",
u"Well, I still don't seem to 𝒄𝒂𝒓𝒆", u"Well, I still don't seem to 𝒄𝒂𝒓𝒆",
u"" u"",
u".!?",
u"no punctuation",
u"one",
] ]
...@@ -62,7 +69,7 @@ class FakeAIAlgorithmTest(AIAlgorithmTest): ...@@ -62,7 +69,7 @@ class FakeAIAlgorithmTest(AIAlgorithmTest):
def test_train_and_score(self): def test_train_and_score(self):
classifier = self.algorithm.train_classifier(EXAMPLES) classifier = self.algorithm.train_classifier(EXAMPLES)
expected_scores = [2, 0, 0, 0] expected_scores = [2, 0, 0, 0, 4, 2, 4]
scores = self._scores(classifier, INPUT_ESSAYS) scores = self._scores(classifier, INPUT_ESSAYS)
self.assertEqual(scores, expected_scores) self.assertEqual(scores, expected_scores)
...@@ -112,22 +119,46 @@ class EaseAIAlgorithmTest(AIAlgorithmTest): ...@@ -112,22 +119,46 @@ class EaseAIAlgorithmTest(AIAlgorithmTest):
classifier = self.algorithm.train_classifier(examples) classifier = self.algorithm.train_classifier(examples)
self._scores(classifier, INPUT_ESSAYS) self._scores(classifier, INPUT_ESSAYS)
def test_most_examples_have_same_score(self):
# All training examples have the same score except for one
examples = [
AIAlgorithm.ExampleEssay(u"Test ëṡṡäÿ", 1),
AIAlgorithm.ExampleEssay(u"Another test ëṡṡäÿ", 1),
AIAlgorithm.ExampleEssay(u"Different score", 0),
]
classifier = self.algorithm.train_classifier(examples)
scores = self._scores(classifier, INPUT_ESSAYS)
# Check that we got scores back.
# This is not a very rigorous assertion -- we're mainly
# checking that we got this far without an exception.
self.assertEqual(len(scores), len(INPUT_ESSAYS))
def test_no_examples(self): def test_no_examples(self):
with self.assertRaises(TrainingError): with self.assertRaises(TrainingError):
self.algorithm.train_classifier([]) self.algorithm.train_classifier([])
def test_json_serializable(self):
classifier = self.algorithm.train_classifier(EXAMPLES)
serialized = json.dumps(classifier)
deserialized = json.loads(serialized)
# This should not raise an exception
scores = self._scores(deserialized, INPUT_ESSAYS)
self.assertEqual(len(scores), len(INPUT_ESSAYS))
@mock.patch('openassessment.assessment.worker.algorithm.pickle') @mock.patch('openassessment.assessment.worker.algorithm.pickle')
def test_pickle_serialize_error(self, mock_pickle): def test_pickle_serialize_error(self, mock_pickle):
mock_pickle.dumps.side_effect = Exception("Test error!") mock_pickle.dumps.side_effect = Exception("Test error!")
with self.assertRaises(TrainingError): with self.assertRaises(TrainingError):
self.algorithm.train_classifier(EXAMPLES) self.algorithm.train_classifier(EXAMPLES)
@mock.patch('openassessment.assessment.worker.algorithm.pickle') def test_pickle_deserialize_error(self):
def test_pickle_deserialize_error(self, mock_pickle):
mock_pickle.loads.side_effect = Exception("Test error!")
classifier = self.algorithm.train_classifier(EXAMPLES) classifier = self.algorithm.train_classifier(EXAMPLES)
with self.assertRaises(InvalidClassifier): with mock.patch('openassessment.assessment.worker.algorithm.pickle.loads') as mock_call:
self.algorithm.score(u"Test ëṡṡäÿ", classifier, {}) mock_call.side_effect = Exception("Test error!")
with self.assertRaises(InvalidClassifier):
self.algorithm.score(u"Test ëṡṡäÿ", classifier, {})
def test_serialized_classifier_not_a_dict(self): def test_serialized_classifier_not_a_dict(self):
with self.assertRaises(InvalidClassifier): with self.assertRaises(InvalidClassifier):
......
""" """
Define the ML algorithms used to train text classifiers. Define the ML algorithms used to train text classifiers.
""" """
try:
import cPickle as pickle
except ImportError:
import pickle
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
import importlib import importlib
import traceback import traceback
import pickle import base64
from django.conf import settings from django.conf import settings
...@@ -315,8 +320,8 @@ class EaseAIAlgorithm(AIAlgorithm): ...@@ -315,8 +320,8 @@ class EaseAIAlgorithm(AIAlgorithm):
""" """
try: try:
return { return {
'feature_extractor': pickle.dumps(feature_ext), 'feature_extractor': base64.b64encode(pickle.dumps(feature_ext)),
'score_classifier': pickle.dumps(classifier), 'score_classifier': base64.b64encode(pickle.dumps(classifier)),
} }
except Exception as ex: except Exception as ex:
msg = ( msg = (
...@@ -343,7 +348,8 @@ class EaseAIAlgorithm(AIAlgorithm): ...@@ -343,7 +348,8 @@ class EaseAIAlgorithm(AIAlgorithm):
raise InvalidClassifier("Classifier must be a dictionary.") raise InvalidClassifier("Classifier must be a dictionary.")
try: try:
feature_extractor = pickle.loads(classifier_data.get('feature_extractor')) classifier_str = classifier_data.get('feature_extractor').encode('utf-8')
feature_extractor = pickle.loads(base64.b64decode(classifier_str))
except Exception as ex: except Exception as ex:
msg = ( msg = (
u"An error occurred while deserializing the " u"An error occurred while deserializing the "
...@@ -352,7 +358,8 @@ class EaseAIAlgorithm(AIAlgorithm): ...@@ -352,7 +358,8 @@ class EaseAIAlgorithm(AIAlgorithm):
raise InvalidClassifier(msg) raise InvalidClassifier(msg)
try: try:
score_classifier = pickle.loads(classifier_data.get('score_classifier')) score_classifier_str = classifier_data.get('score_classifier').encode('utf-8')
score_classifier = pickle.loads(base64.b64decode(score_classifier_str))
except Exception as ex: except Exception as ex:
msg = ( msg = (
u"An error occurred while deserializing the " u"An error occurred while deserializing the "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment