diff --git a/dev_requirements.txt b/dev_requirements.txt
index 6180d33c3..bbcdc9d6d 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -10,6 +10,6 @@ pytest-xdist>=2.1.0,<3
 flaky>=3.5.3,<4
 
 # Test requirements
-git+https://github.com/fishtown-analytics/dbt-adapter-tests.git@feature/add-integration-test-tools
+git+https://github.com/fishtown-analytics/dbt-adapter-tests.git@33872d1cc0f936677dae091c3e0b49771c280514
 sasl==0.2.1
 thrift_sasl==0.4.1
diff --git a/test/custom/base.py b/test/custom/base.py
index ed34e878c..1b4c886bc 100644
--- a/test/custom/base.py
+++ b/test/custom/base.py
@@ -1,7 +1,70 @@
-from dbt_adapter_tests import DBTIntegrationTestBase, use_profile
+import pytest
+from functools import wraps
+import os
+from dbt_adapter_tests import DBTIntegrationTestBase
+
 
 class DBTSparkIntegrationTest(DBTIntegrationTestBase):
-    
+
+    def get_profile(self, adapter_type):
+        if adapter_type == 'apache_spark':
+            return self.apache_spark_profile()
+        elif adapter_type == 'databricks_cluster':
+            return self.databricks_cluster_profile()
+        elif adapter_type == 'databricks_sql_endpoint':
+            return self.databricks_sql_endpoint_profile()
+        else:
+            raise ValueError('invalid adapter type {}'.format(adapter_type))
+
+    @staticmethod
+    def _profile_from_test_name(test_name):
+        adapter_names = ('apache_spark', 'databricks_cluster',
+                         'databricks_sql_endpoint')
+        adapters_in_name = sum(x in test_name for x in adapter_names)
+        if adapters_in_name != 1:
+            raise ValueError(
+                'test names must have exactly 1 profile choice embedded, {} has {}'
+                .format(test_name, adapters_in_name)
+            )
+
+        for adapter_name in adapter_names:
+            if adapter_name in test_name:
+                return adapter_name
+
+        raise ValueError(
+            'could not find adapter name in test name {}'.format(test_name)
+        )
+
+    def run_sql(self, query, fetch='None', kwargs=None, connection_name=None):
+        if connection_name is None:
+            connection_name = '__test'
+
+        if query.strip() == "":
+            return
+
+        sql = self.transform_sql(query, kwargs=kwargs)
+
+        with self.get_connection(connection_name) as conn:
+            cursor = conn.handle.cursor()
+            try:
+                cursor.execute(sql)
+                if fetch == 'one':
+                    return cursor.fetchall()[0]
+                elif fetch == 'all':
+                    return cursor.fetchall()
+                else:
+                    # we have to fetch.
+                    cursor.fetchall()
+            except Exception as e:
+                conn.handle.rollback()
+                conn.transaction_open = False
+                print(sql)
+                print(e)
+                raise
+            else:
+                conn.handle.commit()
+                conn.transaction_open = False
+
     def apache_spark_profile(self):
         return {
             'config': {
@@ -14,13 +77,13 @@ def apache_spark_profile(self):
                         'host': 'localhost',
                         'user': 'dbt',
                         'method': 'thrift',
-                        'port': '10000',
-                        'connect_retries': '5',
-                        'connect_timeout': '60',
+                        'port': 10000,
+                        'connect_retries': 5,
+                        'connect_timeout': 60,
                         'schema': self.unique_schema()
                     },
+                },
                 'target': 'default2'
-            }
             }
         }
 
@@ -40,11 +103,11 @@ def databricks_cluster_profile(self):
                         'port': 443,
                         'schema': self.unique_schema()
                     },
+                },
                 'target': 'odbc'
-            }
             }
         }
-    
+
     def databricks_sql_endpoint_profile(self):
         return {
             'config': {
@@ -61,7 +124,34 @@
                         'port': 443,
                         'schema': self.unique_schema()
                     },
+                },
                 'target': 'default2'
-            }
             }
         }
+
+
+def use_profile(profile_name):
+    """A decorator to declare a test method as using a particular profile.
+    Handles both setting the nose attr and calling self.use_profile.
+
+    Use like this:
+
+    class TestSomething(DBTSparkIntegrationTest):
+        @use_profile('apache_spark')
+        def test_something_apache_spark(self):
+            self.run_dbt(['run'])
+
+        @use_profile('databricks_cluster')
+        def test_something_databricks_cluster(self):
+            self.run_dbt(['run'])
+    """
+    def outer(wrapped):
+        @getattr(pytest.mark, 'profile_' + profile_name)
+        @wraps(wrapped)
+        def func(self, *args, **kwargs):
+            return wrapped(self, *args, **kwargs)
+        # sanity check at import time
+        assert DBTSparkIntegrationTest._profile_from_test_name(
+            wrapped.__name__) == profile_name
+        return func
+    return outer
diff --git a/test/custom/conftest.py b/test/custom/conftest.py
new file mode 100644
index 000000000..02248bae3
--- /dev/null
+++ b/test/custom/conftest.py
@@ -0,0 +1,10 @@
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "profile_databricks_cluster"
+    )
+    config.addinivalue_line(
+        "markers", "profile_databricks_sql_endpoint"
+    )
+    config.addinivalue_line(
+        "markers", "profile_apache_spark"
+    )
diff --git a/test/custom/incremental_strategies/test_incremental_strategies.py b/test/custom/incremental_strategies/test_incremental_strategies.py
index 5880d2fcb..5ad7a3f79 100644
--- a/test/custom/incremental_strategies/test_incremental_strategies.py
+++ b/test/custom/incremental_strategies/test_incremental_strategies.py
@@ -1,4 +1,6 @@
-from test.custom.base import DBTSparkIntegrationTest
+from test.custom.base import DBTSparkIntegrationTest, use_profile
+import dbt.exceptions
+
 
 class TestIncrementalStrategies(DBTSparkIntegrationTest):
     @property
@@ -14,73 +16,80 @@ def run_and_test(self):
         self.run_dbt(["run"])
         self.assertTablesEqual("default_append", "expected_append")
 
+
 class TestDefaultAppend(TestIncrementalStrategies):
     @use_profile("apache_spark")
     def test_default_append_apache_spark(self):
         self.run_and_test()
-    
+
     @use_profile("databricks_cluster")
-    def test_default_append_databricks(self):
+    def test_default_append_databricks_cluster(self):
         self.run_and_test()
 
+
 class TestInsertOverwrite(TestIncrementalStrategies):
     @property
     def models(self):
         return "models_insert_overwrite"
-    
+
     def run_and_test(self):
         self.run_dbt(["seed"])
         self.run_dbt(["run"])
-        self.assertTablesEqual("insert_overwrite_no_partitions", "expected_overwrite")
-        self.assertTablesEqual("insert_overwrite_partitions", "expected_upsert")
-    
+        self.assertTablesEqual(
+            "insert_overwrite_no_partitions", "expected_overwrite")
+        self.assertTablesEqual(
+            "insert_overwrite_partitions", "expected_upsert")
+
     @use_profile("apache_spark")
     def test_insert_overwrite_apache_spark(self):
         self.run_and_test()
-    
+
     @use_profile("databricks_cluster")
-    def test_insert_overwrite_databricks(self):
+    def test_insert_overwrite_databricks_cluster(self):
         self.run_and_test()
 
+
 class TestDeltaStrategies(TestIncrementalStrategies):
     @property
     def models(self):
         return "models_delta"
-    
+
     def run_and_test(self):
         self.run_dbt(["seed"])
         self.run_dbt(["run"])
         self.assertTablesEqual("append_delta", "expected_append")
         self.assertTablesEqual("merge_no_key", "expected_append")
         self.assertTablesEqual("merge_unique_key", "expected_upsert")
-    
+
     @use_profile("databricks_cluster")
-    def test_delta_strategies_databricks(self):
+    def test_delta_strategies_databricks_cluster(self):
         self.run_and_test()
 
+
 class TestBadStrategies(TestIncrementalStrategies):
     @property
     def models(self):
         return "models_insert_overwrite"
-    
+
     def run_and_test(self):
         with self.assertRaises(dbt.exceptions.Exception) as exc:
             self.run_dbt(["compile"])
         message = str(exc.exception)
         self.assertIn("Invalid file format provided", message)
         self.assertIn("Invalid incremental strategy provided", message)
-    
+
     @use_profile("apache_spark")
     def test_bad_strategies_apache_spark(self):
         self.run_and_test()
-    
+
     @use_profile("databricks_cluster")
-    def test_bad_strategies_databricks(self):
+    def test_bad_strategies_databricks_cluster(self):
        self.run_and_test()
-    
+
+
 class TestBadStrategyWithEndpoint(TestInsertOverwrite):
     @use_profile("databricks_sql_endpoint")
-    def run_and_test(self):
+    def test_bad_strategies_databricks_sql_endpoint(self):
         with self.assertRaises(dbt.exceptions.Exception) as exc:
             self.run_dbt(["compile"], "--target", "odbc-sql-endpoint")
         message = str(exc.exception)