From 0cf38bcd192d13048515f45f0f166a6222cb0435 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 07:21:59 -0600 Subject: [PATCH 1/4] set quote policy on all cls.Relation.create() invocations --- dbt/adapters/default/impl.py | 30 ++++++++++++++++++++++++------ dbt/adapters/postgres/impl.py | 6 +++++- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 7bae81e3e1f..fdcf31f1ee7 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -131,6 +131,14 @@ def cancel_connection(cls, project_cfg, connection): raise dbt.exceptions.NotImplementedException( '`cancel_connection` is not implemented for this adapter!') + @classmethod + def _quote_policy(cls, project_cfg): + quoting_cfg = project_cfg.get('quoting', {}) + for quote_key in ('dtabase', 'schema', 'identifier'): + if quote_key not in quoting_cfg: + quoting_cfg[quote_key] = cls.DEFAULT_QUOTE + return quoting_cfg + ### # FUNCTIONS THAT SHOULD BE ABSTRACT ### @@ -154,7 +162,8 @@ def drop(cls, profile, project_cfg, schema, relation = cls.Relation.create( schema=schema, identifier=identifier, - type=relation_type) + type=relation_type, + quote_policy=cls._quote_policy(project_cfg)) return cls.drop_relation(profile, project_cfg, relation, model_name) @@ -175,7 +184,8 @@ def truncate(cls, profile, project_cfg, schema, table, model_name=None): relation = cls.Relation.create( schema=schema, identifier=table, - type='table') + type='table', + quote_policy=cls._quote_policy(project_cfg)) return cls.truncate_relation(profile, project_cfg, relation, model_name) @@ -190,12 +200,20 @@ def truncate_relation(cls, profile, project_cfg, @classmethod def rename(cls, profile, project_cfg, schema, from_name, to_name, model_name=None): + quote_policy = cls._quote_policy(project_cfg) + from_relation = cls.Relation.create( + schema=schema, + identifier=from_name, + quote_policy=quote_policy + ) + to_relation = cls.Relation.create( + identifier=to_name, + quote_policy=quote_policy + ) return cls.rename_relation( profile, project_cfg, - from_relation=cls.Relation.create( - schema=schema, identifier=from_name), - to_relation=cls.Relation.create( - identifier=to_name), + from_relation=from_relation, + to_relation=to_relation, model_name=model_name) @classmethod diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 24ed054f79c..873a48f1828 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -125,7 +125,11 @@ def alter_column_type(cls, profile, project, schema, table, column_name, 4. Rename the new column to existing column """ - relation = cls.Relation.create(schema=schema, identifier=table) + relation = cls.Relation.create( + schema=schema, + identifier=table, + quote_policy=cls._quote_policy(project) + ) opts = { "relation": relation, From 99a04e95127302b5ea4ee67016690674fdc05f2a Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 09:44:53 -0600 Subject: [PATCH 2/4] add tests for quoting --- test/unit/test_postgres_adapter.py | 85 ++++++++++++++++++++++++++ test/unit/test_snowflake_adapter.py | 94 +++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 test/unit/test_snowflake_adapter.py diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index c5b6d05b39c..c6d0593fcf1 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -7,6 +7,7 @@ from dbt.adapters.postgres import PostgresAdapter from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger # noqa +from psycopg2 import extensions as psycopg2_extensions class TestPostgresAdapter(unittest.TestCase): @@ -79,3 +80,87 @@ def test_set_zero_keepalive(self, psycopg2): port=5432, connect_timeout=10) + +class TestConnectingPostgresAdapter(unittest.TestCase): + def setUp(self): + flags.STRICT_MODE = False + + self.profile = { + 'dbname': 'postgres', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + + self.project = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + 'quoting': { + 'identifier': False, + 'schema': True, + } + } + + self.handle = mock.MagicMock(spec=psycopg2_extensions.connection) + self.cursor = self.handle.cursor.return_value + self.mock_execute = self.cursor.execute + self.patcher = mock.patch('dbt.adapters.postgres.impl.psycopg2') + self.psycopg2 = self.patcher.start() + + self.psycopg2.connect.return_value = self.handle + conn = PostgresAdapter.get_connection(self.profile) + + def tearDown(self): + # we want a unique self.handle every time. + PostgresAdapter.cleanup_connections() + self.patcher.stop() + + def test_quoting_on_drop_schema(self): + PostgresAdapter.drop_schema( + profile=self.profile, + project_cfg=self.project, + schema='test_schema' + ) + + self.mock_execute.assert_has_calls([ + mock.call('drop schema if exists "test_schema" cascade', None) + ]) + + def test_quoting_on_drop(self): + PostgresAdapter.drop( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + relation='test_table', + relation_type='table' + ) + self.mock_execute.assert_has_calls([ + mock.call('drop table if exists "test_schema".test_table cascade', None) + ]) + + def test_quoting_on_truncate(self): + PostgresAdapter.truncate( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + table='test_table' + ) + self.mock_execute.assert_has_calls([ + mock.call('truncate table "test_schema".test_table', None) + ]) + + def test_quoting_on_rename(self): + PostgresAdapter.rename( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + from_name='table_a', + to_name='table_b' + ) + self.mock_execute.assert_has_calls([ + mock.call('alter table "test_schema".table_a rename to table_b', None) + ]) diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py new file mode 100644 index 00000000000..71746e2e30b --- /dev/null +++ b/test/unit/test_snowflake_adapter.py @@ -0,0 +1,94 @@ +import mock +import unittest + +import dbt.flags as flags + +import dbt.adapters +from dbt.adapters.snowflake import SnowflakeAdapter +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger # noqa +from snowflake import connector as snowflake_connector + +class TestSnowflakeAdapter(unittest.TestCase): + def setUp(self): + flags.STRICT_MODE = False + + self.profile = { + 'dbname': 'postgres', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + + self.project = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + 'quoting': { + 'identifier': False, + 'schema': True, + } + } + + self.handle = mock.MagicMock(spec=snowflake_connector.SnowflakeConnection) + self.cursor = self.handle.cursor.return_value + self.mock_execute = self.cursor.execute + self.patcher = mock.patch('dbt.adapters.snowflake.impl.snowflake.connector.connect') + self.snowflake = self.patcher.start() + + self.snowflake.return_value = self.handle + conn = SnowflakeAdapter.get_connection(self.profile) + + def tearDown(self): + # we want a unique self.handle every time. + SnowflakeAdapter.cleanup_connections() + self.patcher.stop() + + def test_quoting_on_drop_schema(self): + SnowflakeAdapter.drop_schema( + profile=self.profile, + project_cfg=self.project, + schema='test_schema' + ) + + self.mock_execute.assert_has_calls([ + mock.call('drop schema if exists "test_schema" cascade', None) + ]) + + def test_quoting_on_drop(self): + SnowflakeAdapter.drop( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + relation='test_table', + relation_type='table' + ) + self.mock_execute.assert_has_calls([ + mock.call('drop table if exists "test_schema".test_table cascade', None) + ]) + + def test_quoting_on_truncate(self): + SnowflakeAdapter.truncate( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + table='test_table' + ) + self.mock_execute.assert_has_calls([ + mock.call('truncate table "test_schema".test_table', None) + ]) + + def test_quoting_on_rename(self): + SnowflakeAdapter.rename( + profile=self.profile, + project_cfg=self.project, + schema='test_schema', + from_name='table_a', + to_name='table_b' + ) + self.mock_execute.assert_has_calls([ + mock.call('alter table "test_schema".table_a rename to table_b', None) + ]) From f2d153779cbbe1bbc03e0e1d3cb7c23b4dedb4b7 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 12 Sep 2018 07:26:13 -0600 Subject: [PATCH 3/4] fix typo --- dbt/adapters/default/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 12d85d034d3..e8b969046c7 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -152,7 +152,7 @@ def cancel_connection(cls, project_cfg, connection): @classmethod def _quote_policy(cls, project_cfg): quoting_cfg = project_cfg.get('quoting', {}) - for quote_key in ('dtabase', 'schema', 'identifier'): + for quote_key in ('database', 'schema', 'identifier'): if quote_key not in quoting_cfg: quoting_cfg[quote_key] = cls.DEFAULT_QUOTE return quoting_cfg From 7cbec9ee8f5c7cf53049d0b5527e274947a90ca8 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 12 Sep 2018 08:22:25 -0600 Subject: [PATCH 4/4] PR feedback --- dbt/adapters/default/impl.py | 19 +++++-------------- dbt/adapters/postgres/impl.py | 2 +- dbt/adapters/snowflake/impl.py | 2 -- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index e8b969046c7..0956dd517f2 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -43,8 +43,6 @@ def test(row): class DefaultAdapter(object): - DEFAULT_QUOTE = True - requires = {} context_functions = [ @@ -149,14 +147,6 @@ def cancel_connection(cls, project_cfg, connection): raise dbt.exceptions.NotImplementedException( '`cancel_connection` is not implemented for this adapter!') - @classmethod - def _quote_policy(cls, project_cfg): - quoting_cfg = project_cfg.get('quoting', {}) - for quote_key in ('database', 'schema', 'identifier'): - if quote_key not in quoting_cfg: - quoting_cfg[quote_key] = cls.DEFAULT_QUOTE - return quoting_cfg - ### # FUNCTIONS THAT SHOULD BE ABSTRACT ### @@ -181,7 +171,7 @@ def drop(cls, profile, project_cfg, schema, schema=schema, identifier=identifier, type=relation_type, - quote_policy=cls._quote_policy(project_cfg)) + quote_policy=project_cfg.get('quoting', {})) return cls.drop_relation(profile, project_cfg, relation, model_name) @@ -203,7 +193,7 @@ def truncate(cls, profile, project_cfg, schema, table, model_name=None): schema=schema, identifier=table, type='table', - quote_policy=cls._quote_policy(project_cfg)) + quote_policy=project_cfg.get('quoting', {})) return cls.truncate_relation(profile, project_cfg, relation, model_name) @@ -218,7 +208,7 @@ def truncate_relation(cls, profile, project_cfg, @classmethod def rename(cls, profile, project_cfg, schema, from_name, to_name, model_name=None): - quote_policy = cls._quote_policy(project_cfg) + quote_policy = project_cfg.get('quoting', {}) from_relation = cls.Relation.create( schema=schema, identifier=from_name, @@ -765,7 +755,8 @@ def _quote_as_configured(cls, project_cfg, identifier, quote_key): """This is the actual implementation of quote_as_configured, without the extra arguments needed for use inside materialization code. """ - if project_cfg.get('quoting', {}).get(quote_key, cls.DEFAULT_QUOTE): + default = cls.Relation.DEFAULTS['quote_policy'].get(quote_key) + if project_cfg.get('quoting', {}).get(quote_key, default): return cls.quote(identifier) else: return identifier diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 873a48f1828..065529bf416 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -128,7 +128,7 @@ def alter_column_type(cls, profile, project, schema, table, column_name, relation = cls.Relation.create( schema=schema, identifier=table, - quote_policy=cls._quote_policy(project) + quote_policy=project.get('quoting', {}) ) opts = { diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index af7ef8d22a6..1b9f1812429 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -18,8 +18,6 @@ class SnowflakeAdapter(PostgresAdapter): - DEFAULT_QUOTE = False - Relation = SnowflakeRelation @classmethod