Skip to content

Commit

Permalink
Expose snowflake query_id in snowflake hook and operator, support multiple statements in sql string (apache#15533)
Browse files Browse the repository at this point in the history
  • Loading branch information
mobuchowski authored Apr 30, 2021
1 parent 5e79b1e commit c6be8b1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 1 deletion.
9 changes: 9 additions & 0 deletions airflow/providers/snowflake/example_dags/example_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
# INSERT template: the table name is interpolated by the f-string now, while the
# ``%(id)s`` placeholder is left in place for the %-formatting on the next line.
SQL_INSERT_STATEMENT = f"INSERT INTO {SNOWFLAKE_SAMPLE_TABLE} VALUES ('name', %(id)s)"
# One concrete INSERT statement per id in range(0, 10).
SQL_LIST = [SQL_INSERT_STATEMENT % {"id": n} for n in range(0, 10)]
# The same statements joined with "; " into a single multi-statement SQL string.
SQL_MULTIPLE_STMTS = "; ".join(SQL_LIST)
# Selects up to 10 (name, id) rows from the sample table.
SNOWFLAKE_SLACK_SQL = f"SELECT name, id FROM {SNOWFLAKE_SAMPLE_TABLE} LIMIT 10;"
SNOWFLAKE_SLACK_MESSAGE = (
"Results in an ASCII table:\n```{{ results_df | tabulate(tablefmt='pretty', headers='keys') }}```"
Expand Down Expand Up @@ -86,6 +87,13 @@
task_id='snowflake_op_sql_list', dag=dag, snowflake_conn_id=SNOWFLAKE_CONN_ID, sql=SQL_LIST
)

# Passes one SQL string that contains several "; "-separated statements
# (SQL_MULTIPLE_STMTS), as opposed to the list form used by snowflake_op_sql_list.
snowflake_op_sql_multiple_stmts = SnowflakeOperator(
task_id='snowflake_op_sql_multiple_stmts',
dag=dag,
snowflake_conn_id=SNOWFLAKE_CONN_ID,
sql=SQL_MULTIPLE_STMTS,
)

snowflake_op_template_file = SnowflakeOperator(
task_id='snowflake_op_template_file',
dag=dag,
Expand Down Expand Up @@ -130,6 +138,7 @@
snowflake_op_sql_list,
snowflake_op_template_file,
copy_into_table,
snowflake_op_sql_multiple_stmts,
]
>> slack_report
)
53 changes: 52 additions & 1 deletion airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional, Tuple
from contextlib import closing
from typing import Any, Dict, Optional, Tuple, Union

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(self, *args, **kwargs) -> None:
self.schema = kwargs.pop("schema", None)
self.authenticator = kwargs.pop("authenticator", None)
self.session_parameters = kwargs.pop("session_parameters", None)
self.query_ids = []

def _get_conn_params(self) -> Dict[str, Optional[str]]:
"""
Expand Down Expand Up @@ -245,3 +247,52 @@ def set_autocommit(self, conn, autocommit: Any) -> None:

def get_autocommit(self, conn):
return getattr(conn, 'autocommit_mode', False)

def run(self, sql: Union[str, list], autocommit: bool = False, parameters: Optional[dict] = None):
    """
    Run a command or a list of commands against Snowflake.

    A string is handed to the connection's ``execute_string`` and may
    therefore hold several statements; a list is executed sequentially,
    one statement at a time, on a single cursor. The Snowflake query id of
    every executed statement is collected in ``self.query_ids``, which is
    reset at the start of each call.

    :param sql: the sql string to be executed with possibly multiple statements,
        or a list of sql statements to execute
    :type sql: str or list
    :param autocommit: What to set the connection's autocommit setting to
        before executing the query.
    :type autocommit: bool
    :param parameters: The parameters to render the SQL query with. They are
        only applied when ``sql`` is a list; the string path does not use them.
    :type parameters: dict or iterable
    """
    self.query_ids = []

    # Open exactly one connection. ``closing`` yields the connection object
    # itself and guarantees ``close()`` on exit, so no second ``get_conn()``
    # call is needed and no connection is left open behind us.
    with closing(self.get_conn()) as conn:
        self.set_autocommit(conn, autocommit)

        if isinstance(sql, str):
            # execute_string returns one cursor per statement found in the string.
            cursors = conn.execute_string(sql, return_cursors=True)
            for cur in cursors:
                self.query_ids.append(cur.sfqid)

                self.log.info("Rows affected: %s", cur.rowcount)
                self.log.info("Snowflake query id: %s", cur.sfqid)
                cur.close()

        elif isinstance(sql, list):
            self.log.debug("Executing %d statements against Snowflake DB", len(sql))
            with closing(conn.cursor()) as cur:
                for sql_statement in sql:
                    self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
                    if parameters:
                        cur.execute(sql_statement, parameters)
                    else:
                        cur.execute(sql_statement)
                    self.log.info("Rows affected: %s", cur.rowcount)
                    self.log.info("Snowflake query id: %s", cur.sfqid)
                    self.query_ids.append(cur.sfqid)

        # If autocommit was set to False for a db that supports autocommit,
        # or if the db does not support autocommit, we do a manual commit.
        if not self.get_autocommit(conn):
            conn.commit()
2 changes: 2 additions & 0 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []

def get_hook(self) -> SnowflakeHook:
"""
Expand All @@ -120,3 +121,4 @@ def execute(self, context: Any) -> None:
self.log.info('Executing: %s', self.sql)
hook = self.get_hook()
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
self.query_ids = hook.query_ids
19 changes: 19 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ def setUp(self):
super().setUp()

self.cur = mock.MagicMock()
self.cur2 = mock.MagicMock()

self.cur.sfqid = 'uuid'
self.cur2.sfqid = 'uuid2'

self.conn = conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
self.conn.execute_string.return_value = [self.cur, self.cur2]

self.conn.login = 'user'
self.conn.password = 'pw'
Expand Down Expand Up @@ -89,6 +95,19 @@ def test_get_uri(self):
)
assert uri_shouldbe == self.db_hook.get_uri()

def test_single_element_list_calls_execute(self):
    """A one-item statement list is executed on the cursor and its query id is recorded."""
    statements = ['select * from table']
    self.db_hook.run(statements)

    self.cur.execute.assert_called()
    expected_query_ids = ['uuid']
    assert self.db_hook.query_ids == expected_query_ids

def test_passed_string_calls_execute_string(self):
    """A plain SQL string goes through ``execute_string`` and each returned cursor is closed."""
    multi_statement_sql = 'select * from table; select * from table2'
    self.db_hook.run(multi_statement_sql)

    # One query id per cursor returned by execute_string, in order.
    assert self.db_hook.query_ids == ['uuid', 'uuid2']
    self.conn.execute_string.assert_called()
    for cursor in (self.cur, self.cur2):
        cursor.close.assert_called()

def test_get_conn_params(self):
conn_params_shouldbe = {
'user': 'user',
Expand Down

0 comments on commit c6be8b1

Please sign in to comment.