Skip to content

Commit

Permalink
Add AWS Athena profile mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-awd committed Oct 6, 2023
1 parent 435a699 commit 0e08feb
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Type


from .athena import AthenaAccessKeyProfileMapping
from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
Expand All @@ -21,6 +22,7 @@
from .trino.ldap import TrinoLDAPProfileMapping

profile_mappings: list[Type[BaseProfileMapping]] = [
AthenaAccessKeyProfileMapping,
GoogleCloudServiceAccountFileProfileMapping,
GoogleCloudServiceAccountDictProfileMapping,
GoogleCloudOauthProfileMapping,
Expand Down
5 changes: 5 additions & 0 deletions cosmos/profiles/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"Athena Airflow connection -> dbt profile mappings"

from .access_key import AthenaAccessKeyProfileMapping

__all__ = ["AthenaAccessKeyProfileMapping"]
59 changes: 59 additions & 0 deletions cosmos/profiles/athena/access_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key."
from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class AthenaAccessKeyProfileMapping(BaseProfileMapping):
"""
Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key.
https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html
"""

airflow_connection_type: str = "aws"
dbt_profile_type: str = "athena"
is_community: bool = True

required_fields = [
"aws_access_key_id",
"aws_secret_access_key",
"database",
"region_name",
"s3_staging_dir",
"schema",
]
secret_fields = [
"aws_secret_access_key",
]
airflow_param_mapping = {
"aws_access_key_id": "login",
"aws_secret_access_key": "password",
"aws_profile_name": "extra.aws_profile_name",
"database": "extra.database",
"debug_query_state": "extra.debug_query_state",
"lf_tags_database": "extra.lf_tags_database",
"num_retries": "extra.num_retries",
"poll_interval": "extra.poll_interval",
"region_name": "extra.region_name",
"s3_data_dir": "extra.s3_data_dir",
"s3_data_naming": "extra.s3_data_naming",
"s3_staging_dir": "extra.s3_staging_dir",
"schema": "extra.schema",
"seed_s3_upload_args": "extra.seed_s3_upload_args",
"work_group": "extra.work_group",
}

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
**self.mapped_params,
**self.profile_args,
# aws_secret_access_key should always get set as env var
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
}
return self.filter_null(profile)
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies = [

[project.optional-dependencies]
dbt-all = [
"dbt-athena",
"dbt-bigquery",
"dbt-databricks",
"dbt-exasol",
Expand All @@ -54,6 +55,9 @@ dbt-all = [
"dbt-snowflake",
"dbt-spark",
]
dbt-athena = [
"dbt-athena-community",
]
dbt-bigquery = [
"dbt-bigquery",
]
Expand Down
154 changes: 154 additions & 0 deletions tests/profiles/athena/test_athena_access_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"Tests for the Athena profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping


@pytest.fixture()
def mock_athena_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_athena_connection",
conn_type="aws",
login="my_aws_access_key_id",
password="my_aws_secret_key",
extra=json.dumps(
{
"database": "my_database",
"region_name": "my_region",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_athena_connection_claiming() -> None:
"""
Tests that the Athena profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == aws
# and the following exist:
# - login
# - password
# - database
# - region_name
# - s3_staging_dir
# - schema
potential_values = {
"conn_type": "aws",
"login": "my_aws_access_key_id",
"password": "my_aws_secret_key",
"extra": json.dumps(
{
"database": "my_database",
"region_name": "my_region",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
),
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
# should raise an InvalidMappingException
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert profile_mapping.can_claim_connection()


def test_athena_profile_mapping_selected(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)
assert isinstance(profile_mapping, AthenaAccessKeyProfileMapping)


def test_athena_profile_args(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"database": mock_athena_conn.extra_dejson.get("database"),
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
"schema": mock_athena_conn.extra_dejson.get("schema"),
}


def test_athena_profile_args_overrides(
mock_athena_conn: Connection,
) -> None:
"""
Tests that you can override the profile values for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
profile_args={"schema": "my_custom_schema", "database": "my_custom_db"},
)
assert profile_mapping.profile_args == {
"schema": "my_custom_schema",
"database": "my_custom_db",
}

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"database": "my_custom_db",
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
"schema": "my_custom_schema",
}


def test_athena_profile_env_vars(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password,
}

0 comments on commit 0e08feb

Please sign in to comment.