Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snowflake private key #378

Merged
merged 9 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .postgres.user_pass import PostgresUserPasswordProfileMapping
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -24,6 +25,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
ExasolUserPasswordProfileMapping,
TrinoLDAPProfileMapping,
Expand Down
3 changes: 2 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"Snowflake Airflow connection -> dbt profile mapping."

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping

__all__ = ["SnowflakeUserPasswordProfileMapping"]
__all__ = ["SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping"]
87 changes: 87 additions & 0 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection

Check warning on line 10 in cosmos/profiles/snowflake/user_privatekey.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/snowflake/user_privatekey.py#L10

Added line #L10 was not covered by tests


class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

tatiana marked this conversation as resolved.
Show resolved Hide resolved
airflow_connection_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key_content",
]
secret_fields = [
"private_key_content",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key_content": "extra.private_key_content",
}

def __init__(self, conn: Connection, profile_args: dict[str, Any | None] | None = None) -> None:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.

This standardizes the keys to be 'account', 'database', etc.
"""
conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

self.conn = conn
self.profile_args = profile_args or {}
super().__init__(conn, profile_args)

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
"type": "snowflake",
"account": self.account,
"user": self.user,
"schema": self.schema,
"database": self.database,
"role": self.conn.extra_dejson.get("role"),
"warehouse": self.conn.extra_dejson.get("warehouse"),
**self.profile_args,
# private_key should always get set as env var
"private_key_content": self.get_env_var_format("private_key_content"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
8 changes: 8 additions & 0 deletions docs/dbt/connections-profiles.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ Username and Password
:members:


Username and Private Key
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: cosmos.profiles.snowflake.SnowflakePrivateKeyPemProfileMapping
:undoc-members:
:members:


Spark
-----

Expand Down
4 changes: 2 additions & 2 deletions tests/profiles/snowflake/test_snowflake_user_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake profile."
"Tests for the Snowflake user/password profile."

import json
from unittest.mock import patch
Expand All @@ -18,7 +18,7 @@ def mock_snowflake_conn(): # type: ignore
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_id="my_snowflake_password_connection",
conn_type="snowflake",
login="my_user",
password="my_password",
Expand Down
238 changes: 238 additions & 0 deletions tests/profiles/snowflake/test_snowflake_user_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
"Tests for the Snowflake user/private key profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_profile_mapping
from cosmos.profiles.snowflake import (
SnowflakePrivateKeyPemProfileMapping,
)


@pytest.fixture()
def mock_snowflake_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_pk_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
extra=json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": "my_private_key",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the Snowflake profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == snowflake
# and the following exist:
# - user
# - private key
# - account
# - database
# - warehouse
# - schema
potential_values = {
"conn_type": "snowflake",
"login": "my_user",
"schema": "my_database",
"extra": json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": "my_private_key",
}
),
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

profile_mapping = SnowflakePrivateKeyPemProfileMapping(
conn,
)
assert not profile_mapping.can_claim_connection()

# test when we're missing the account
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the database
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# test when we're missing the warehouse
conn = Connection(**potential_values) # type: ignore
conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}'
print("testing with", conn.extra)
profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert isinstance(profile_mapping, SnowflakePrivateKeyPemProfileMapping)


def test_profile_args(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_profile_mapping(
mock_snowflake_conn.conn_id,
)

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
"schema": mock_snowflake_conn.schema,
"account": mock_snowflake_conn.extra_dejson.get("account"),
"database": mock_snowflake_conn.extra_dejson.get("database"),
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_args_overrides(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_profile_mapping(
mock_snowflake_conn.conn_id,
profile_args={"database": "my_db_override"},
)
assert profile_mapping.profile_args == {
"database": "my_db_override",
}

assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
"schema": mock_snowflake_conn.schema,
"account": mock_snowflake_conn.extra_dejson.get("account"),
"database": "my_db_override",
"warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"),
}


def test_profile_env_vars(
mock_snowflake_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_profile_mapping(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT": mock_snowflake_conn.extra_dejson.get("private_key_content"),
}


def test_old_snowflake_format() -> None:
"""
Tests that the old format still works.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
extra=json.dumps(
{
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)

profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
"warehouse": conn.extra_dejson.get("warehouse"),
}


def test_appends_region() -> None:
"""
Tests that region is appended to account if it doesn't already exist.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_type="snowflake",
login="my_user",
schema="my_schema",
extra=json.dumps(
{
"account": "my_account",
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": "my_private_key",
}
),
)

profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn)
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
"schema": conn.schema,
"account": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}",
"database": conn.extra_dejson.get("database"),
"warehouse": conn.extra_dejson.get("warehouse"),
}