Commit 973a0a17 by James Rowan

Merge pull request #140 from edx/jamesrowan/patch-vertica-load

Merged #140 
parents c4917a60 fb19697b
"""Collect the course catalog from the course catalog API for processing of course metadata like subjects or types""" """Collect the course catalog from the course catalog API for processing of course metadata like subjects or types."""
import requests import requests
import datetime import datetime
...@@ -140,7 +140,7 @@ class DailyLoadSubjectsToVerticaTask(PullCatalogMixin, VerticaCopyTask): ...@@ -140,7 +140,7 @@ class DailyLoadSubjectsToVerticaTask(PullCatalogMixin, VerticaCopyTask):
@property @property
def auto_primary_key(self): def auto_primary_key(self):
"""Overridden since the database schema specifies a different name for the auto incrementing primary key.""" """Overridden since the database schema specifies a different name for the auto incrementing primary key."""
return None return ('row_number', 'AUTO_INCREMENT')
@property @property
def default_columns(self): def default_columns(self):
...@@ -150,7 +150,6 @@ class DailyLoadSubjectsToVerticaTask(PullCatalogMixin, VerticaCopyTask): ...@@ -150,7 +150,6 @@ class DailyLoadSubjectsToVerticaTask(PullCatalogMixin, VerticaCopyTask):
@property @property
def columns(self): def columns(self):
return [ return [
('row_number', 'AUTO_INCREMENT PRIMARY KEY'),
('course_id', 'VARCHAR(200)'), ('course_id', 'VARCHAR(200)'),
('date', 'DATE'), ('date', 'DATE'),
('subject_uri', 'VARCHAR(200)'), ('subject_uri', 'VARCHAR(200)'),
......
...@@ -64,7 +64,7 @@ class VerticaCopyTaskTest(unittest.TestCase): ...@@ -64,7 +64,7 @@ class VerticaCopyTaskTest(unittest.TestCase):
def create_task(self, credentials=None, source=None, overwrite=False, cls=CopyToVerticaDummyTable): def create_task(self, credentials=None, source=None, overwrite=False, cls=CopyToVerticaDummyTable):
""" """
Emulate execution of a generic VerticaTask. Emulate execution of a generic VerticaCopyTask.
""" """
# Make sure to flush the instance cache so we create # Make sure to flush the instance cache so we create
# a new task object. # a new task object.
...@@ -129,7 +129,7 @@ class VerticaCopyTaskTest(unittest.TestCase): ...@@ -129,7 +129,7 @@ class VerticaCopyTaskTest(unittest.TestCase):
self.create_task().create_table(connection) self.create_task().create_table(connection)
connection.cursor().execute.assert_called_once_with( connection.cursor().execute.assert_called_once_with(
"CREATE TABLE IF NOT EXISTS testing.dummy_table " "CREATE TABLE IF NOT EXISTS testing.dummy_table "
"(id AUTO_INCREMENT PRIMARY KEY,course_id VARCHAR(255)," "(id AUTO_INCREMENT,course_id VARCHAR(255),"
"interval_start DATETIME,interval_end DATETIME,label VARCHAR(255)," "interval_start DATETIME,interval_end DATETIME,label VARCHAR(255),"
"count INT,created TIMESTAMP DEFAULT NOW(),PRIMARY KEY (id))" "count INT,created TIMESTAMP DEFAULT NOW(),PRIMARY KEY (id))"
) )
...@@ -155,7 +155,8 @@ class VerticaCopyTaskTest(unittest.TestCase): ...@@ -155,7 +155,8 @@ class VerticaCopyTaskTest(unittest.TestCase):
def _get_expected_query(self): def _get_expected_query(self):
"""Returns query that should be generated for copying into the table.""" """Returns query that should be generated for copying into the table."""
query = ("COPY {schema}.dummy_table FROM STDIN DELIMITER AS E'\t' NULL AS '\\N' DIRECT NO COMMIT;" query = ("COPY {schema}.dummy_table (course_id,interval_start,interval_end,label,count) "
"FROM STDIN DELIMITER AS E'\t' NULL AS '\\N' DIRECT NO COMMIT;"
.format(schema=self.create_task().schema)) .format(schema=self.create_task().schema))
return query return query
...@@ -174,39 +175,36 @@ class VerticaCopyTaskTest(unittest.TestCase): ...@@ -174,39 +175,36 @@ class VerticaCopyTaskTest(unittest.TestCase):
task = self.create_task(source=self._get_source_string(1)) task = self.create_task(source=self._get_source_string(1))
cursor = MagicMock() cursor = MagicMock()
task.copy_data_table_from_target(cursor) task.copy_data_table_from_target(cursor)
query = cursor.copy_file.call_args[0][0] query = cursor.copy_stream.call_args[0][0]
self.assertEquals(query, self._get_expected_query()) self.assertEquals(query, self._get_expected_query())
file_to_copy = cursor.copy_file.call_args[0][1] file_to_copy = cursor.copy_stream.call_args[0][1]
with task.input()['insert_source'].open('r') as expected_data: with task.input()['insert_source'].open('r') as expected_data:
expected_source = expected_data.read() expected_source = expected_data.read()
with file_to_copy as sent_data: sent_source = file_to_copy.read()
sent_source = sent_data.read()
self.assertEquals(sent_source, expected_source) self.assertEquals(sent_source, expected_source)
def test_copy_multiple_rows(self): def test_copy_multiple_rows(self):
task = self.create_task(source=self._get_source_string(4)) task = self.create_task(source=self._get_source_string(4))
cursor = MagicMock() cursor = MagicMock()
task.copy_data_table_from_target(cursor) task.copy_data_table_from_target(cursor)
query = cursor.copy_file.call_args[0][0] query = cursor.copy_stream.call_args[0][0]
self.assertEquals(query, self._get_expected_query()) self.assertEquals(query, self._get_expected_query())
file_to_copy = cursor.copy_file.call_args[0][1] file_to_copy = cursor.copy_stream.call_args[0][1]
with task.input()['insert_source'].open('r') as expected_data: with task.input()['insert_source'].open('r') as expected_data:
expected_source = expected_data.read() expected_source = expected_data.read()
with file_to_copy as sent_data: sent_source = file_to_copy.read()
sent_source = sent_data.read()
self.assertEquals(sent_source, expected_source) self.assertEquals(sent_source, expected_source)
def test_copy_to_predefined_table(self): def test_copy_to_predefined_table(self):
task = self.create_task(cls=CopyToPredefinedVerticaDummyTable) task = self.create_task(cls=CopyToPredefinedVerticaDummyTable)
cursor = MagicMock() cursor = MagicMock()
task.copy_data_table_from_target(cursor) task.copy_data_table_from_target(cursor)
query = cursor.copy_file.call_args[0][0] query = cursor.copy_stream.call_args[0][0]
self.assertEquals(query, self._get_expected_query()) self.assertEquals(query, self._get_expected_query())
file_to_copy = cursor.copy_file.call_args[0][1] file_to_copy = cursor.copy_stream.call_args[0][1]
with task.input()['insert_source'].open('r') as expected_data: with task.input()['insert_source'].open('r') as expected_data:
expected_source = expected_data.read() expected_source = expected_data.read()
with file_to_copy as sent_data: sent_source = file_to_copy.read()
sent_source = sent_data.read()
self.assertEquals(sent_source, expected_source) self.assertEquals(sent_source, expected_source)
@with_luigi_config(('vertica-export', 'schema', 'foobar')) @with_luigi_config(('vertica-export', 'schema', 'foobar'))
...@@ -219,7 +217,7 @@ class VerticaCopyTaskTest(unittest.TestCase): ...@@ -219,7 +217,7 @@ class VerticaCopyTaskTest(unittest.TestCase):
call("CREATE SCHEMA IF NOT EXISTS foobar"), call("CREATE SCHEMA IF NOT EXISTS foobar"),
call( call(
"CREATE TABLE IF NOT EXISTS foobar.dummy_table " "CREATE TABLE IF NOT EXISTS foobar.dummy_table "
"(id AUTO_INCREMENT PRIMARY KEY,course_id VARCHAR(255)," "(id AUTO_INCREMENT,course_id VARCHAR(255),"
"interval_start DATETIME,interval_end DATETIME,label VARCHAR(255)," "interval_start DATETIME,interval_end DATETIME,label VARCHAR(255),"
"count INT,created TIMESTAMP DEFAULT NOW(),PRIMARY KEY (id))" "count INT,created TIMESTAMP DEFAULT NOW(),PRIMARY KEY (id))"
) )
......
...@@ -29,6 +29,7 @@ class VerticaCopyTaskMixin(OverwriteOutputMixin): ...@@ -29,6 +29,7 @@ class VerticaCopyTaskMixin(OverwriteOutputMixin):
credentials: Path to the external access credentials file. credentials: Path to the external access credentials file.
schema: The schema to which to write. schema: The schema to which to write.
insert_chunk_size: The number of rows to insert at a time.
""" """
schema = luigi.Parameter( schema = luigi.Parameter(
config_path={'section': 'vertica-export', 'name': 'schema'} config_path={'section': 'vertica-export', 'name': 'schema'}
...@@ -41,7 +42,6 @@ class VerticaCopyTaskMixin(OverwriteOutputMixin): ...@@ -41,7 +42,6 @@ class VerticaCopyTaskMixin(OverwriteOutputMixin):
class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
""" """
A task for copying into a Vertica database. A task for copying into a Vertica database.
""" """
required_tasks = None required_tasks = None
output_target = None output_target = None
...@@ -78,7 +78,12 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -78,7 +78,12 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
@property @property
def auto_primary_key(self): def auto_primary_key(self):
"""Tuple defining name and definition of an auto-incrementing primary key, or None.""" """Tuple defining name and definition of an auto-incrementing primary key, or None."""
return ('id', 'AUTO_INCREMENT PRIMARY KEY') return ('id', 'AUTO_INCREMENT')
@property
def foreign_key_mapping(self):
"""Dictionary of column_name: (schema.table, column) pairs representing foreign key constraints."""
return {}
@property @property
def default_columns(self): def default_columns(self):
...@@ -99,6 +104,28 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -99,6 +104,28 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
log.debug(query) log.debug(query)
connection.cursor().execute(query) connection.cursor().execute(query)
def create_column_definitions(self):
"""
Builds the list of column definitions for the table to be loaded.
Assumes that columns are specified as (name, definition) tuples.
:return a string to be used in a SQL query to create the table
"""
columns = []
if self.auto_primary_key is not None:
columns.append(self.auto_primary_key)
columns.extend(self.columns)
if self.default_columns is not None:
columns.extend(self.default_columns)
if self.auto_primary_key is not None:
columns.append(("PRIMARY KEY", "({name})".format(name=self.auto_primary_key[0])))
coldefs = ','.join(
'{name} {definition}'.format(name=name, definition=definition) for name, definition in columns
)
return coldefs
def create_table(self, connection): def create_table(self, connection):
""" """
Override to provide code for creating the target table, if not existing. Override to provide code for creating the target table, if not existing.
...@@ -106,7 +133,6 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -106,7 +133,6 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
Requires the schema to exist first. Requires the schema to exist first.
By default it will be created using types (optionally) specified in columns. By default it will be created using types (optionally) specified in columns.
If overridden, use the provided connection object for setting If overridden, use the provided connection object for setting
up the table in order to create the table and insert data up the table in order to create the table and insert data
using the same transaction. using the same transaction.
...@@ -120,20 +146,17 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -120,20 +146,17 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
) )
# Assumes that columns are specified as (name, definition) tuples # Assumes that columns are specified as (name, definition) tuples
columns = [] coldefs = self.create_column_definitions()
if self.auto_primary_key is not None:
columns.append(self.auto_primary_key)
columns.extend(self.columns)
if self.default_columns is not None:
columns.extend(self.default_columns)
if self.auto_primary_key is not None:
columns.append(("PRIMARY KEY", "({name})".format(name=self.auto_primary_key[0])))
coldefs = ','.join( foreign_key_defs = ''
'{name} {definition}'.format(name=name, definition=definition) for name, definition in columns for column in self.foreign_key_mapping:
) foreign_key_defs += ", FOREIGN KEY ({col}) REFERENCES {other_schema_and_table} ({other_col})".format(
query = "CREATE TABLE IF NOT EXISTS {schema}.{table} ({coldefs})".format( col=column, other_schema_and_table=self.foreign_key_mapping[column][0],
schema=self.schema, table=self.table, coldefs=coldefs other_col=self.foreign_key_mapping[column][1]
)
query = "CREATE TABLE IF NOT EXISTS {schema}.{table} ({coldefs}{foreign_key_defs})".format(
schema=self.schema, table=self.table, coldefs=coldefs, foreign_key_defs=foreign_key_defs
) )
log.debug(query) log.debug(query)
connection.cursor().execute(query) connection.cursor().execute(query)
...@@ -176,10 +199,12 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -176,10 +199,12 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
# first clear the appropriate rows from the luigi Vertica marker table # first clear the appropriate rows from the luigi Vertica marker table
marker_table = self.output().marker_table # side-effect: sets self.output_target if it's None marker_table = self.output().marker_table # side-effect: sets self.output_target if it's None
try: try:
query = "DELETE FROM {marker_table} where `target_table`='{target_table}'".format( query = "DELETE FROM {schema}.{marker_table} where target_table='{schema}.{target_table}';".format(
schema=self.schema,
marker_table=marker_table, marker_table=marker_table,
target_table=self.table, target_table=self.table,
) )
log.debug(query)
connection.cursor().execute(query) connection.cursor().execute(query)
except vertica_python.errors.Error as err: except vertica_python.errors.Error as err:
if (type(err) is vertica_python.errors.MissingRelation) or ('Sqlstate: 42V01' in err.args[0]): if (type(err) is vertica_python.errors.MissingRelation) or ('Sqlstate: 42V01' in err.args[0]):
...@@ -191,8 +216,17 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -191,8 +216,17 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
# Use "DELETE" instead of TRUNCATE since TRUNCATE forces an implicit commit before it executes which would # Use "DELETE" instead of TRUNCATE since TRUNCATE forces an implicit commit before it executes which would
# commit the currently open transaction before continuing with the copy. # commit the currently open transaction before continuing with the copy.
query = "DELETE FROM {schema}.{table}".format(schema=self.schema, table=self.table) query = "DELETE FROM {schema}.{table}".format(schema=self.schema, table=self.table)
log.debug(query)
connection.cursor().execute(query) connection.cursor().execute(query)
# vertica-python and its maintainers intentionally avoid supporting open
# transactions like we do when self.overwrite=True (DELETE a bunch of rows
# and then COPY some), per https://github.com/uber/vertica-python/issues/56.
# The DELETE commands in this method will cause the connection to see some
# messages that will prevent it from trying to copy any data (if the cursor
# successfully executes the DELETEs), so we flush the message buffer.
connection.cursor().flush_to_query_ready()
@property @property
def copy_delimiter(self): def copy_delimiter(self):
"""The delimiter in the data to be copied. Default is tab (\t)""" """The delimiter in the data to be copied. Default is tab (\t)"""
...@@ -205,10 +239,27 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -205,10 +239,27 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
def copy_data_table_from_target(self, cursor): def copy_data_table_from_target(self, cursor):
"""Performs the copy query from the insert source.""" """Performs the copy query from the insert source."""
cursor.copy_file("COPY {schema}.{table} FROM STDIN DELIMITER AS {delim} NULL AS {null} DIRECT NO COMMIT;" if isinstance(self.columns[0], basestring):
.format(schema=self.schema, table=self.table, delim=self.copy_delimiter, column_names = ','.join([name for name in self.columns])
null=self.copy_null_sequence), elif len(self.columns[0]) == 2:
self.input()['insert_source'].open('r'), decoder='utf-8') column_names = ','.join([name for name, _type in self.columns])
else:
raise Exception('columns must consist of column strings or '
'(column string, type string) tuples (was %r ...)'
% (self.columns[0],))
with self.input()['insert_source'].open('r') as insert_source_file:
log.debug("Running copy_stream from source file")
cursor.copy_stream(
"COPY {schema}.{table} ({cols}) FROM STDIN DELIMITER AS {delim} NULL AS {null} DIRECT NO COMMIT;".format(
schema=self.schema,
table=self.table,
cols=column_names,
delim=self.copy_delimiter,
null=self.copy_null_sequence
),
insert_source_file
)
def run(self): def run(self):
""" """
...@@ -236,7 +287,9 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task): ...@@ -236,7 +287,9 @@ class VerticaCopyTask(VerticaCopyTaskMixin, luigi.Task):
# We commit only if both operations completed successfully. # We commit only if both operations completed successfully.
connection.commit() connection.commit()
except Exception: log.debug("Committed transaction.")
except Exception as exc:
log.debug("Rolled back the transaction; exception raised: %s", str(exc))
connection.rollback() connection.rollback()
raise raise
finally: finally:
...@@ -253,13 +306,13 @@ class CredentialFileVerticaTarget(VerticaTarget): ...@@ -253,13 +306,13 @@ class CredentialFileVerticaTarget(VerticaTarget):
Represents a table in Vertica, is complete when the update_id is the same as a previous successful execution. Represents a table in Vertica, is complete when the update_id is the same as a previous successful execution.
Arguments: Arguments:
credentials_target (luigi.Target): A target that can be read to retrieve the hostname, port and user credentials credentials_target (luigi.Target): A target that can be read to retrieve the hostname, port and user credentials
that will be used to connect to the database. that will be used to connect to the database.
database_name (str): The name of the database that the table exists in. Note this database need not exist. database_name (str): The name of the database that the table exists in. Note this database need not exist.
table (str): The name of the table in the database that is being modified. table (str): The name of the table in the database that is being modified.
update_id (str): A unique identifier for this update to the table. Subsequent updates with identical update_id update_id (str): A unique identifier for this update to the table. Subsequent updates with identical update_id
values will not be executed. values will not be executed.
""" """
def __init__(self, credentials_target, schema, table, update_id): def __init__(self, credentials_target, schema, table, update_id):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment