Commit 0c84d255 by Brian Wilson Committed by Gerrit Code Review

Merge "Add SqoopImportTask and SqoopImportFromMysql."

parents 41cafc3b 84828de0
......@@ -69,7 +69,6 @@ class PathSetTask(luigi.Task):
else:
return self.generate_file_list()
def complete(self):
# An optimization: just declare that the task is always
# complete, by definition, because it is whatever files were
......
"""
Gather data using Sqoop table dumps run on RDBMS databases.
"""
import json
import luigi
import luigi.hadoop
import luigi.hdfs
import luigi.configuration
from edx.analytics.tasks.url import ExternalURL
from edx.analytics.tasks.url import get_target_from_url
from edx.analytics.tasks.url import url_path_join
def load_sqoop_cmd():
"""Get path to sqoop command from Luigi configuration."""
return luigi.configuration.get_config().get('sqoop', 'command', 'sqoop')
class SqoopImportTask(luigi.hadoop.BaseHadoopJobTask):
"""
An abstract task that uses Sqoop to read data out of a database and
writes it to a file in CSV format.
In order to protect the database access credentials they are
loaded from an external file which can be secured appropriately.
The credentials file is expected to be JSON formatted and contain
a simple map specifying the host, port, username password and
database.
Parameters:
credentials: Path to the external access credentials file.
destination: The directory to write the output files to.
table_name: The name of the table to import.
num_mappers: The number of map tasks to ask Sqoop to use.
where: A 'where' clause to be passed to Sqoop. Note that
no spaces should be embedded and special characters should
be escaped. For example: --where "id\<50".
Example Credentials File::
{
"host": "db.example.com",
"port": "3306",
"username": "exampleuser",
"password": "example password",
"database": "exampledata"
}
"""
# TODO: Defaults from config file
credentials = luigi.Parameter()
destination = luigi.Parameter()
table_name = luigi.Parameter()
num_mappers = luigi.Parameter(default=None)
where = luigi.Parameter(default=None)
def requires(self):
return {
'credentials': ExternalURL(url=self.credentials),
}
def output(self):
return get_target_from_url(url_path_join(self.destination, self.table_name))
def job_runner(self):
"""Use simple runner that gets args from the job and passes through."""
return SqoopImportRunner()
def get_arglist(self, password_file):
"""Returns list of arguments for running Sqoop."""
arglist = [load_sqoop_cmd(), 'import']
# Generic args should be passed to sqoop first, followed by import-specific args.
arglist.extend(self.generic_args(password_file))
arglist.extend(self.import_args())
return arglist
def generic_args(self, password_target):
"""Returns list of arguments used by all Sqoop commands, using credentials read from file."""
cred = self._get_credentials()
url = self.connection_url(cred)
generic_args = ['--connect', url, '--username', cred['username']]
# write password to temp file object, and pass name of file to Sqoop:
with password_target.open('w') as password_file:
password_file.write(cred['password'])
password_file.flush()
generic_args.extend(['--password-file', password_target.path])
return generic_args
def import_args(self):
"""Returns list of arguments specific to Sqoop import."""
arglist = ['--table', self.table_name, '--warehouse-dir', self.destination]
if self.num_mappers is not None:
arglist.extend(['--num-mappers', str(self.num_mappers)])
if self.where is not None:
arglist.extend(['--where', str(self.where)])
return arglist
def connection_url(self, _cred):
"""Construct connection URL from provided credentials."""
raise NotImplementedError # pragma: no cover
def _get_credentials(self):
"""
Gathers the secure connection parameters from an external file
and uses them to establish a connection to the database
specified in the secure parameters.
Returns:
A dict containing credentials.
"""
cred = {}
with self.input()['credentials'].open('r') as credentials_file:
cred = json.load(credentials_file)
return cred
class SqoopImportFromMysql(SqoopImportTask):
"""
An abstract task that uses Sqoop to read data out of a database and
writes it to a file in CSV format.
Output format is defined by meaning of --mysql-delimiters option,
which defines defaults used by mysqldump tool:
* fields delimited by comma
* lines delimited by \n
* delimiters escaped by backslash
* delimiters optionally enclosed by single quotes (')
"""
def connection_url(self, cred):
"""Construct connection URL from provided credentials."""
return 'jdbc:mysql://{host}/{database}'.format(**cred)
def import_args(self):
"""Returns list of arguments specific to Sqoop import from a Mysql database."""
arglist = super(SqoopImportFromMysql, self).import_args()
arglist.extend(['--direct', '--mysql-delimiters'])
return arglist
class SqoopImportRunner(luigi.hadoop.JobRunner):
"""Runs a SqoopImportTask by shelling out to sqoop."""
def run_job(self, job):
"""Runs a SqoopImportTask by shelling out to sqoop."""
# Create a temp file in HDFS to store the password,
# so it isn't echoed by the hadoop job code.
# It should be deleted when it goes out of scope
# (using __del__()), but safer to just make sure.
try:
password_target = luigi.hdfs.HdfsTarget(is_tmp=True)
arglist = job.get_arglist(password_target)
luigi.hadoop.run_and_track_hadoop_job(arglist)
finally:
password_target.remove()
"""Tests for Sqoop import task."""
import textwrap
from mock import MagicMock
from mock import patch
from mock import sentinel
from edx.analytics.tasks.sqoop import SqoopImportFromMysql
from edx.analytics.tasks.tests import unittest
from edx.analytics.tasks.tests.target import FakeTarget
class SqoopImportFromMysqlTestCase(unittest.TestCase):
"""
Ensure we can pass the right arguments to Sqoop.
"""
def setUp(self):
patcher = patch('luigi.hdfs.HdfsTarget')
self.mock_hdfstarget = patcher.start()
self.addCleanup(patcher.stop)
self.mock_hdfstarget().path = "/temp/password_file"
patcher2 = patch("luigi.hadoop.run_and_track_hadoop_job")
self.mock_run = patcher2.start()
self.addCleanup(patcher2.stop)
def run_task(self, credentials=None, num_mappers=None, where=None):
"""
Emulate execution of a generic MysqlTask.
"""
if not credentials:
credentials = '''\
{
"host": "db.example.com",
"port": "3306",
"username": "exampleuser",
"password": "example password",
"database": "exampledata"
}'''
task = SqoopImportFromMysql(
credentials=sentinel.ignored,
destination="/fake/destination",
table_name="example_table",
num_mappers=num_mappers,
where=where
)
fake_input = {
'credentials': FakeTarget(textwrap.dedent(credentials))
}
task.input = MagicMock(return_value=fake_input)
task.run()
arglist = self.mock_run.call_args[0][0]
return arglist
def test_connect_with_missing_credentials(self):
with self.assertRaises(KeyError):
self.run_task('{}')
self.assertTrue(self.mock_hdfstarget().remove.called)
self.assertFalse(self.mock_run.called)
def test_connect_with_credential_syntax_error(self):
with self.assertRaises(ValueError):
self.run_task('{')
self.assertTrue(self.mock_hdfstarget().remove.called)
self.assertFalse(self.mock_run.called)
def test_connect_with_complete_credentials(self):
arglist = self.run_task()
self.assertTrue(self.mock_run.called)
expected_arglist = [
'sqoop',
'import',
'--connect',
'jdbc:mysql://db.example.com/exampledata',
'--username',
'exampleuser',
'--password-file',
'/temp/password_file',
'--table',
'example_table',
'--warehouse-dir',
'/fake/destination',
'--direct',
'--mysql-delimiters'
]
self.assertEquals(arglist, expected_arglist)
self.assertTrue(self.mock_hdfstarget().remove.called)
def test_connect_with_where_args(self):
arglist = self.run_task(where='id < 50')
self.assertEquals(arglist[-4], '--where')
self.assertEquals(arglist[-3], 'id < 50')
def test_connect_with_num_mappers(self):
arglist = self.run_task(num_mappers=50)
self.assertEquals(arglist[-4], '--num-mappers')
self.assertEquals(arglist[-3], '50')
......@@ -29,6 +29,7 @@ edx.analytics.tasks =
inc-enrollments-report = edx.analytics.tasks.reports.incremental_enrollments:WeeklyIncrementalUsersAndEnrollments
course-enroll = edx.analytics.tasks.course_enroll:CourseEnrollmentChangesPerDay
answer_dist = edx.analytics.tasks.answer_dist:AnswerDistributionPerCourse
sqoop-import = edx.analytics.tasks.sqoop:SqoopImportFromMysql
mapreduce.engine =
hadoop = edx.analytics.tasks.mapreduce:MapReduceJobRunner
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment