Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add runtime per-model warehouse config on snowflake models (#1358) #1788

Merged
merged 2 commits into from
Sep 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from contextlib import contextmanager
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping
)

import agate
Expand Down Expand Up @@ -1010,3 +1011,28 @@ def calculate_freshness(
'snapshotted_at': snapshotted_at,
'age': age,
}

def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
    """Run an adapter-specific operation before a model is materialized.

    A connection can be assumed to be available while the hook runs.

    ``config`` is the configuration dictionary that is also exposed in the
    materialization context; hooks should treat it as read-only.

    Whatever the hook returns is kept as a context value and handed to
    ``post_model_hook`` afterwards. This default implementation does
    nothing and returns ``None``.
    """
    return None

def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
    """A hook for running some operation after the model materialization
    runs. The hook can assume it has a connection available.

    The first parameter is a configuration dictionary (the same one
    available in the materialization context). It should be considered
    read-only.

    The second parameter is the value returned by pre_model_hook.

    This default implementation does nothing.
    """
    pass
12 changes: 6 additions & 6 deletions core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ def emit(self, record: logbook.LogRecord):


# we still need to use logging to suppress these or pytest captures them
logging.getLogger('botocore').setLevel(logging.INFO)
logging.getLogger('requests').setLevel(logging.INFO)
logging.getLogger('urllib3').setLevel(logging.INFO)
logging.getLogger('google').setLevel(logging.INFO)
logging.getLogger('snowflake.connector').setLevel(logging.INFO)
logging.getLogger('parsedatetime').setLevel(logging.INFO)
logging.getLogger('botocore').setLevel(logging.ERROR)
logging.getLogger('requests').setLevel(logging.ERROR)
logging.getLogger('urllib3').setLevel(logging.ERROR)
logging.getLogger('google').setLevel(logging.ERROR)
logging.getLogger('snowflake.connector').setLevel(logging.ERROR)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌 🙏

logging.getLogger('parsedatetime').setLevel(logging.ERROR)
# want to see werkzeug logs about errors
logging.getLogger('werkzeug').setLevel(logging.ERROR)

Expand Down
13 changes: 12 additions & 1 deletion core/dbt/node_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,18 @@ def execute(self, model, manifest):
if materialization_macro is None:
missing_materialization(model, self.adapter.type())

result = materialization_macro.generator(context)()
if 'config' not in context:
raise InternalException(
'Invalid materialization context generated, missing config: {}'
.format(context)
)
context_config = context['config']

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = materialization_macro.generator(context)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)

for relation in self._materialization_relations(result, model):
self.adapter.cache_added(relation.incorporate(dbt_created=True))
Expand Down
1 change: 1 addition & 0 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def _rollback_handle(cls, connection):
"""On snowflake, rolling back the handle of an aborted session raises
an exception.
"""
logger.debug('initiating rollback')
try:
connection.handle.rollback()
except snowflake.connector.errors.ProgrammingError as e:
Expand Down
36 changes: 35 additions & 1 deletion plugins/snowflake/dbt/adapters/snowflake/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Mapping, Any, Optional

from dbt.adapters.sql import SQLAdapter
from dbt.adapters.snowflake import SnowflakeConnectionManager
from dbt.adapters.snowflake import SnowflakeRelation
from dbt.utils import filter_null_values
from dbt.exceptions import RuntimeException


class SnowflakeAdapter(SQLAdapter):
Expand All @@ -10,7 +13,7 @@ class SnowflakeAdapter(SQLAdapter):

AdapterSpecificConfigs = frozenset(
{"transient", "cluster_by", "automatic_clustering", "secure",
"copy_grants"}
"copy_grants", "warehouse"}
)

@classmethod
Expand Down Expand Up @@ -40,3 +43,34 @@ def _make_match_kwargs(self, database, schema, identifier):
return filter_null_values(
{"identifier": identifier, "schema": schema, "database": database}
)

def _get_warehouse(self) -> str:
_, table = self.execute(
'select current_warehouse() as warehouse',
fetch=True
)
if len(table) == 0 or len(table[0]) == 0:
# can this happen?
raise RuntimeException(
'Could not get current warehouse: no results'
)
return str(table[0][0])

def _use_warehouse(self, warehouse: str):
"""Use the given warehouse. Quotes are never applied."""
self.execute('use warehouse {}'.format(warehouse))

def pre_model_hook(self, config: Mapping[str, Any]) -> Optional[str]:
default_warehouse = self.config.credentials.warehouse
warehouse = config.get('warehouse', default_warehouse)
if warehouse == default_warehouse or warehouse is None:
return None
previous = self._get_warehouse()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i like that this cost is only incurred if a non-default warehouse is specified!

self._use_warehouse(warehouse)
return previous

def post_model_hook(
    self, config: Mapping[str, Any], context: Optional[str]
) -> None:
    """Switch back to the warehouse named by ``context``, if there is one.

    ``context`` is the value returned by ``pre_model_hook``. When it is
    ``None`` no statement is issued.
    """
    if context is None:
        return
    self._use_warehouse(context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(materialized='table') }}
select 'DBT_TEST_ALT' as warehouse
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(warehouse='DBT_TEST_DOES_NOT_EXIST') }}
select current_warehouse() as warehouse
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(warehouse='DBT_TEST_ALT', materialized='table') }}
select current_warehouse() as warehouse
28 changes: 28 additions & 0 deletions test/integration/050_warehouse_test/test_warehouses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from test.integration.base import DBTIntegrationTest, use_profile
import os


class TestDebug(DBTIntegrationTest):
    """Integration tests for the per-model ``warehouse`` config."""

    @property
    def schema(self):
        """Schema name used by this test case."""
        return 'dbt_warehouse_050'

    @staticmethod
    def dir(value):
        """Return ``value`` as a normalized path."""
        normalized = os.path.normpath(value)
        return normalized

    @property
    def models(self):
        """Normalized path of the models directory."""
        return self.dir('models')

    @use_profile('snowflake')
    def test_snowflake_override_ok(self):
        """Run both models and check that their relations are equal."""
        selected = ['override_warehouse', 'expected_warehouse']
        self.run_dbt(['run', '--models'] + selected)
        self.assertManyRelationsEqual(
            [['OVERRIDE_WAREHOUSE'], ['EXPECTED_WAREHOUSE']]
        )

    @use_profile('snowflake')
    def test_snowflake_override_noexist(self):
        """Running the ``invalid_warehouse`` model is expected to fail."""
        command = ['run', '--models', 'invalid_warehouse']
        self.run_dbt(command, expect_pass=False)
63 changes: 62 additions & 1 deletion test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from contextlib import contextmanager
from unittest import mock

import dbt.flags as flags
Expand Down Expand Up @@ -46,7 +47,8 @@ def setUp(self):
self.cursor = self.handle.cursor.return_value
self.mock_execute = self.cursor.execute
self.patcher = mock.patch(
'dbt.adapters.snowflake.connections.snowflake.connector.connect')
'dbt.adapters.snowflake.connections.snowflake.connector.connect'
)
self.snowflake = self.patcher.start()

self.load_patch = mock.patch('dbt.loader.make_parse_result')
Expand Down Expand Up @@ -133,6 +135,65 @@ def test_quoting_on_rename(self):
)
])

@contextmanager
def current_warehouse(self, response):
    """Temporarily make the mocked cursor answer the warehouse query.

    While the context is active, executing the ``current_warehouse()``
    query sets the cursor's description and makes ``fetchall`` return
    ``response``; any other statement resets the description to None and
    restores the previous ``fetchall`` value. On exit the original
    ``fetchall`` return value and ``execute`` side effect are put back.
    """
    # there is probably some elegant way built into mock.patch to do this
    saved_fetchall = self.cursor.fetchall.return_value
    saved_side_effect = self.mock_execute.side_effect
    warehouse_query = 'select current_warehouse() as warehouse'

    def fake_execute(sql, *args, **kwargs):
        is_warehouse_query = sql == warehouse_query
        self.cursor.description = [['name']] if is_warehouse_query else None
        self.cursor.fetchall.return_value = (
            [[response]] if is_warehouse_query else saved_fetchall
        )
        return self.mock_execute.return_value

    self.mock_execute.side_effect = fake_execute
    try:
        yield
    finally:
        self.cursor.fetchall.return_value = saved_fetchall
        self.mock_execute.side_effect = saved_side_effect

def _strip_transactions(self):
result = []
for call_args in self.mock_execute.call_args_list:
args, kwargs = tuple(call_args)
is_transactional = (
len(kwargs) == 0 and
len(args) == 2 and
args[1] is None and
args[0] in {'BEGIN', 'COMMIT'}
)
if not is_transactional:
result.append(call_args)
return result

def test_pre_post_hooks_warehouse(self):
    """A configured warehouse is switched to, then switched back."""
    query_call = mock.call('select current_warehouse() as warehouse', None)
    switch_call = mock.call('use warehouse other_warehouse', None)
    restore_call = mock.call('use warehouse warehouse', None)

    with self.current_warehouse('warehouse'):
        config = {'warehouse': 'other_warehouse'}
        hook_context = self.adapter.pre_model_hook(config)
        self.assertIsNotNone(hook_context)
        self.mock_execute.assert_has_calls([query_call, switch_call])
        self.adapter.post_model_hook(config, hook_context)
        self.mock_execute.assert_has_calls(
            [query_call, switch_call, restore_call]
        )

def test_pre_post_hooks_no_warehouse(self):
    """With no warehouse in the config, neither hook executes anything."""
    empty_config = {}
    with self.current_warehouse('warehouse'):
        hook_context = self.adapter.pre_model_hook(empty_config)
        self.assertIsNone(hook_context)
        self.mock_execute.assert_not_called()
        self.adapter.post_model_hook(empty_config, hook_context)
        self.mock_execute.assert_not_called()

def test_cancel_open_connections_empty(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

Expand Down