diff --git a/luigi/contrib/sqla.py b/luigi/contrib/sqla.py index 4cc7c26f9c..72e35d3c16 100644 --- a/luigi/contrib/sqla.py +++ b/luigi/contrib/sqla.py @@ -273,6 +273,8 @@ class CopyToTable(luigi.Task): Usage: * subclass and override the required `connection_string`, `table` and `columns` attributes. + * optionally override the `schema` attribute to use a different schema for + the target table. """ _logger = logging.getLogger('luigi-interface') @@ -300,6 +302,11 @@ def table(self): # completely ignore the columns. Instead set the reflect value to True below columns = [] + # Specify the database schema of the target table, if supported by the + # RDBMS. Note that this doesn't change the schema of the marker table. + # The schema MUST already exist in the database, or this will task fail. + schema = '' + # options column_separator = "\t" # how columns are separated in the file copied into postgres chunk_size = 5000 # default chunk size for insert @@ -328,15 +335,21 @@ def construct_sqla_columns(columns): else: # if columns is specified as (name, type) tuples with engine.begin() as con: - metadata = sqlalchemy.MetaData() + + if self.schema: + metadata = sqlalchemy.MetaData(schema=self.schema) + else: + metadata = sqlalchemy.MetaData() + try: - if not con.dialect.has_table(con, self.table): + if not con.dialect.has_table(con, self.table, self.schema or None): sqla_columns = construct_sqla_columns(self.columns) self.table_bound = sqlalchemy.Table(self.table, metadata, *sqla_columns) metadata.create_all(engine) else: - metadata.reflect(only=[self.table], bind=engine) - self.table_bound = metadata.tables[self.table] + full_table = '.'.join([self.schema, self.table]) if self.schema else self.table + metadata.reflect(only=[full_table], bind=engine) + self.table_bound = metadata.tables[full_table] except Exception as e: self._logger.exception(self.table + str(e))