Commit 642687f8 by Gabe Mulley

Fix minor bugs blocking report generation

Change-Id: I5c1d8f476587362b5d20d028aeb6283c4c3b18b0
parent 8ad6b877
...@@ -5,8 +5,6 @@ from __future__ import absolute_import ...@@ -5,8 +5,6 @@ from __future__ import absolute_import
import luigi.hadoop import luigi.hadoop
from stevedore import ExtensionManager
class MapReduceJobTask(luigi.hadoop.JobTask): class MapReduceJobTask(luigi.hadoop.JobTask):
""" """
...@@ -19,6 +17,10 @@ class MapReduceJobTask(luigi.hadoop.JobTask): ...@@ -19,6 +17,10 @@ class MapReduceJobTask(luigi.hadoop.JobTask):
) )
def job_runner(self): def job_runner(self):
# Lazily import this since this module will be loaded on hadoop worker nodes however stevedore will not be
# available in that environment.
from stevedore import ExtensionManager
extension_manager = ExtensionManager('mapreduce.engine') extension_manager = ExtensionManager('mapreduce.engine')
try: try:
engine_class = extension_manager[self.mapreduce_engine].plugin engine_class = extension_manager[self.mapreduce_engine].plugin
......
...@@ -237,6 +237,23 @@ class TestWeeklyAllUsersAndEnrollments(unittest.TestCase): ...@@ -237,6 +237,23 @@ class TestWeeklyAllUsersAndEnrollments(unittest.TestCase):
self.assertEqual(res.loc[self.enrollment_label]['2013-01-08'], 4) self.assertEqual(res.loc[self.enrollment_label]['2013-01-08'], 4)
self.assertEqual(res.loc[self.enrollment_label]['2013-01-15'], 6) self.assertEqual(res.loc[self.enrollment_label]['2013-01-15'], 6)
def test_blacklist_course_not_in_enrollments(self):
enrollments = """
course_1 2013-01-02 1
course_2 2013-01-02 2
course_3 2013-01-02 4
course_2 2013-01-09 1
course_3 2013-01-15 2
"""
blacklist = """
course_4
course_1
course_2
"""
res = self.run_task('', enrollments, '2013-01-15', 2, blacklist=blacklist)
self.assertEqual(res.loc[self.enrollment_label]['2013-01-08'], 4)
self.assertEqual(res.loc[self.enrollment_label]['2013-01-15'], 6)
def test_unicode(self): def test_unicode(self):
course_id = u'course_\u2603' course_id = u'course_\u2603'
......
...@@ -98,8 +98,13 @@ class AllCourseEnrollmentCountMixin(CourseEnrollmentCountMixin): ...@@ -98,8 +98,13 @@ class AllCourseEnrollmentCountMixin(CourseEnrollmentCountMixin):
Returns: Returns:
None, the `course_data` is modified in place. None, the `course_data` is modified in place.
""" """
for course_id in course_blacklist:
try:
# Drop from axis 1 because we are dropping columns, not rows. # Drop from axis 1 because we are dropping columns, not rows.
course_data.drop(course_blacklist, axis=1, inplace=True) course_data.drop(course_id, axis=1, inplace=True)
except ValueError:
# There is no column for this course.
pass
def save_output(self, results, output_file): def save_output(self, results, output_file):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment