diff --git a/impala/sqlalchemy.py b/impala/sqlalchemy.py
index f04d83283..9cf6b5467 100644
--- a/impala/sqlalchemy.py
+++ b/impala/sqlalchemy.py
@@ -19,8 +19,10 @@
 
 import re
 from sqlalchemy.dialects import registry
+
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
-from sqlalchemy.sql.compiler import IdentifierPreparer, GenericTypeCompiler
+from sqlalchemy.sql.compiler import (DDLCompiler, GenericTypeCompiler,
+                                     IdentifierPreparer)
 from sqlalchemy.types import (BOOLEAN, SMALLINT, BIGINT, TIMESTAMP, FLOAT,
                               DECIMAL, Integer, Float, String)
 
@@ -44,6 +46,44 @@ class STRING(String):
     __visit_name__ = 'STRING'
 
 
+class ImpalaDDLCompiler(DDLCompiler):
+    """DDL compiler that renders Impala-specific CREATE TABLE options."""
+
+    def post_create_table(self, table):
+        """Build table-level CREATE options.
+
+        Renders the optional ``impala_partition_by``, ``impala_stored_as``
+        and ``impala_table_properties`` keyword arguments of ``table`` as
+        ``PARTITION BY``, ``STORED AS`` and ``TBLPROPERTIES`` clauses, one
+        per line and in that order.
+
+        Returns an empty string when no option applies, so the statement
+        generated for a table without Impala options is left unchanged.
+        """
+        table_opts = []
+
+        if 'impala_partition_by' in table.kwargs:
+            table_opts.append(
+                'PARTITION BY %s' % table.kwargs['impala_partition_by'])
+
+        if 'impala_stored_as' in table.kwargs:
+            table_opts.append(
+                'STORED AS %s' % table.kwargs['impala_stored_as'])
+
+        # Skip a missing, None or empty mapping so that neither an empty
+        # 'TBLPROPERTIES ()' clause nor an AttributeError is produced.
+        table_properties = table.kwargs.get('impala_table_properties')
+        if table_properties:
+            rendered = ', '.join(
+                "'{0}' = '{1}'".format(property_, value)
+                for property_, value in table_properties.items())
+            table_opts.append('TBLPROPERTIES (%s)' % rendered)
+
+        if not table_opts:
+            return ''
+        return '\n%s' % '\n'.join(table_opts)
+
+
 class ImpalaTypeCompiler(GenericTypeCompiler):
 
     # pylint: disable=unused-argument
@@ -129,6 +169,7 @@ class ImpalaDialect(DefaultDialect):
     supports_native_enum = False
     supports_default_values = False
     returns_unicode_strings = True
+    ddl_compiler = ImpalaDDLCompiler
     type_compiler = ImpalaTypeCompiler
     execution_ctx_cls = ImpalaExecutionContext
 
diff --git a/impala/tests/test_sqlalchemy.py b/impala/tests/test_sqlalchemy.py
index df2b446b4..349804ec5 100644
--- a/impala/tests/test_sqlalchemy.py
+++ b/impala/tests/test_sqlalchemy.py
@@ -30,8 +30,17 @@ def test_sqlalchemy_compilation():
                     Column('col1', STRING),
                     Column('col2', TINYINT),
                     Column('col3', INT),
-                    Column('col4', DOUBLE))
+                    Column('col4', DOUBLE),
+                    impala_partition_by='HASH PARTITIONS 16',
+                    impala_stored_as='KUDU',
+                    impala_table_properties={
+                        'kudu.table_name': 'my_kudu_table',
+                        'kudu.master_addresses': 'kudu-master.example.com:7051'
+                    })
     observed = str(CreateTable(mytable, bind=engine))
     expected = ('\nCREATE TABLE mytable (\n\tcol1 STRING, \n\tcol2 TINYINT, '
-                '\n\tcol3 INT, \n\tcol4 DOUBLE\n)\n\n')
+                '\n\tcol3 INT, \n\tcol4 DOUBLE\n)'
+                '\nPARTITION BY HASH PARTITIONS 16\nSTORED AS KUDU\n'
+                "TBLPROPERTIES ('kudu.table_name' = 'my_kudu_table', "
+                "'kudu.master_addresses' = 'kudu-master.example.com:7051')\n\n")
     assert expected == observed