Commit e237df58 by Will Daly

Merge pull request #446 from edx/will/ai-use-cpickle

Use cPickle instead of pickle; add additional tests
parents ec4094c4 4d8635ba
......@@ -3,6 +3,7 @@
Tests for AI algorithm implementations.
"""
import unittest
import json
import mock
from openassessment.test_utils import CacheResetTest
from openassessment.assessment.worker.algorithm import (
......@@ -16,14 +17,20 @@ EXAMPLES = [
AIAlgorithm.ExampleEssay(u"How years ago in days of old, when 𝒎𝒂𝒈𝒊𝒄 filled th air.", 1),
AIAlgorithm.ExampleEssay(u"Ṫ'ẅäṡ in the darkest depths of Ṁöṛḋöṛ, I met a girl so fair.", 1),
AIAlgorithm.ExampleEssay(u"But goレレuᄊ, and the evil one crept up and slipped away with her", 0),
AIAlgorithm.ExampleEssay(u"", 4)
AIAlgorithm.ExampleEssay(u"", 4),
AIAlgorithm.ExampleEssay(u".!?", 4),
AIAlgorithm.ExampleEssay(u"no punctuation", 4),
AIAlgorithm.ExampleEssay(u"one", 4),
]
INPUT_ESSAYS = [
u"Good times, 𝑩𝒂𝒅 𝑻𝒊𝒎𝒆𝒔, you know I had my share",
u"When my woman left home for a 𝒃𝒓𝒐𝒘𝒏 𝒆𝒚𝒆𝒅 𝒎𝒂𝒏",
u"Well, I still don't seem to 𝒄𝒂𝒓𝒆",
u""
u"",
u".!?",
u"no punctuation",
u"one",
]
......@@ -62,7 +69,7 @@ class FakeAIAlgorithmTest(AIAlgorithmTest):
def test_train_and_score(self):
classifier = self.algorithm.train_classifier(EXAMPLES)
expected_scores = [2, 0, 0, 0]
expected_scores = [2, 0, 0, 0, 4, 2, 4]
scores = self._scores(classifier, INPUT_ESSAYS)
self.assertEqual(scores, expected_scores)
......@@ -112,22 +119,46 @@ class EaseAIAlgorithmTest(AIAlgorithmTest):
classifier = self.algorithm.train_classifier(examples)
self._scores(classifier, INPUT_ESSAYS)
def test_most_examples_have_same_score(self):
# All training examples have the same score except for one
examples = [
AIAlgorithm.ExampleEssay(u"Test ëṡṡäÿ", 1),
AIAlgorithm.ExampleEssay(u"Another test ëṡṡäÿ", 1),
AIAlgorithm.ExampleEssay(u"Different score", 0),
]
classifier = self.algorithm.train_classifier(examples)
scores = self._scores(classifier, INPUT_ESSAYS)
# Check that we got scores back.
# This is not a very rigorous assertion -- we're mainly
# checking that we got this far without an exception.
self.assertEqual(len(scores), len(INPUT_ESSAYS))
def test_no_examples(self):
with self.assertRaises(TrainingError):
self.algorithm.train_classifier([])
def test_json_serializable(self):
classifier = self.algorithm.train_classifier(EXAMPLES)
serialized = json.dumps(classifier)
deserialized = json.loads(serialized)
# This should not raise an exception
scores = self._scores(deserialized, INPUT_ESSAYS)
self.assertEqual(len(scores), len(INPUT_ESSAYS))
@mock.patch('openassessment.assessment.worker.algorithm.pickle')
def test_pickle_serialize_error(self, mock_pickle):
mock_pickle.dumps.side_effect = Exception("Test error!")
with self.assertRaises(TrainingError):
self.algorithm.train_classifier(EXAMPLES)
@mock.patch('openassessment.assessment.worker.algorithm.pickle')
def test_pickle_deserialize_error(self, mock_pickle):
mock_pickle.loads.side_effect = Exception("Test error!")
def test_pickle_deserialize_error(self):
classifier = self.algorithm.train_classifier(EXAMPLES)
with self.assertRaises(InvalidClassifier):
self.algorithm.score(u"Test ëṡṡäÿ", classifier, {})
with mock.patch('openassessment.assessment.worker.algorithm.pickle.loads') as mock_call:
mock_call.side_effect = Exception("Test error!")
with self.assertRaises(InvalidClassifier):
self.algorithm.score(u"Test ëṡṡäÿ", classifier, {})
def test_serialized_classifier_not_a_dict(self):
with self.assertRaises(InvalidClassifier):
......
"""
Define the ML algorithms used to train text classifiers.
"""
try:
import cPickle as pickle
except ImportError:
import pickle
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import importlib
import traceback
import pickle
import base64
from django.conf import settings
......@@ -315,8 +320,8 @@ class EaseAIAlgorithm(AIAlgorithm):
"""
try:
return {
'feature_extractor': pickle.dumps(feature_ext),
'score_classifier': pickle.dumps(classifier),
'feature_extractor': base64.b64encode(pickle.dumps(feature_ext)),
'score_classifier': base64.b64encode(pickle.dumps(classifier)),
}
except Exception as ex:
msg = (
......@@ -343,7 +348,8 @@ class EaseAIAlgorithm(AIAlgorithm):
raise InvalidClassifier("Classifier must be a dictionary.")
try:
feature_extractor = pickle.loads(classifier_data.get('feature_extractor'))
classifier_str = classifier_data.get('feature_extractor').encode('utf-8')
feature_extractor = pickle.loads(base64.b64decode(classifier_str))
except Exception as ex:
msg = (
u"An error occurred while deserializing the "
......@@ -352,7 +358,8 @@ class EaseAIAlgorithm(AIAlgorithm):
raise InvalidClassifier(msg)
try:
score_classifier = pickle.loads(classifier_data.get('score_classifier'))
score_classifier_str = classifier_data.get('score_classifier').encode('utf-8')
score_classifier = pickle.loads(base64.b64decode(score_classifier_str))
except Exception as ex:
msg = (
u"An error occurred while deserializing the "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment