Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWS Athena profile mapping #578

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Type


from .athena import AthenaAccessKeyProfileMapping
from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
Expand All @@ -21,6 +22,7 @@
from .trino.ldap import TrinoLDAPProfileMapping

profile_mappings: list[Type[BaseProfileMapping]] = [
AthenaAccessKeyProfileMapping,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, should this have been included in the all export on line 60?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good spot.

I have an upcoming PR for the Athena profile mapping, so I'll probably squeeze the fix in there.

This shouldn't affect anyone, unless they're doing from cosmos.profiles import *

GoogleCloudServiceAccountFileProfileMapping,
GoogleCloudServiceAccountDictProfileMapping,
GoogleCloudOauthProfileMapping,
Expand Down
5 changes: 5 additions & 0 deletions cosmos/profiles/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"Athena Airflow connection -> dbt profile mappings"

from .access_key import AthenaAccessKeyProfileMapping

__all__ = ["AthenaAccessKeyProfileMapping"]
59 changes: 59 additions & 0 deletions cosmos/profiles/athena/access_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key."
from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class AthenaAccessKeyProfileMapping(BaseProfileMapping):
"""
Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key.

https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html
"""

benjamin-awd marked this conversation as resolved.
Show resolved Hide resolved
airflow_connection_type: str = "aws"
dbt_profile_type: str = "athena"
is_community: bool = True

required_fields = [
"aws_access_key_id",
"aws_secret_access_key",
"database",
"region_name",
"s3_staging_dir",
"schema",
]
secret_fields = [
"aws_secret_access_key",
]
airflow_param_mapping = {
"aws_access_key_id": "login",
"aws_secret_access_key": "password",
"aws_profile_name": "extra.aws_profile_name",
"database": "extra.database",
"debug_query_state": "extra.debug_query_state",
"lf_tags_database": "extra.lf_tags_database",
"num_retries": "extra.num_retries",
"poll_interval": "extra.poll_interval",
"region_name": "extra.region_name",
"s3_data_dir": "extra.s3_data_dir",
"s3_data_naming": "extra.s3_data_naming",
"s3_staging_dir": "extra.s3_staging_dir",
"schema": "extra.schema",
"seed_s3_upload_args": "extra.seed_s3_upload_args",
"work_group": "extra.work_group",
}

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
**self.mapped_params,
**self.profile_args,
# aws_secret_access_key should always get set as env var
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
}
return self.filter_null(profile)
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies = [

[project.optional-dependencies]
dbt-all = [
"dbt-athena",
"dbt-bigquery",
"dbt-databricks",
"dbt-exasol",
Expand All @@ -54,6 +55,9 @@ dbt-all = [
"dbt-snowflake",
"dbt-spark",
]
dbt-athena = [
"dbt-athena-community",
]
dbt-bigquery = [
"dbt-bigquery",
]
Expand Down
154 changes: 154 additions & 0 deletions tests/profiles/athena/test_athena_access_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"Tests for the Athena profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping


@pytest.fixture()
def mock_athena_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_athena_connection",
conn_type="aws",
login="my_aws_access_key_id",
password="my_aws_secret_key",
extra=json.dumps(
{
"database": "my_database",
"region_name": "my_region",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_athena_connection_claiming() -> None:
"""
Tests that the Athena profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == aws
# and the following exist:
# - login
# - password
# - database
# - region_name
# - s3_staging_dir
# - schema
potential_values = {
"conn_type": "aws",
"login": "my_aws_access_key_id",
"password": "my_aws_secret_key",
"extra": json.dumps(
{
"database": "my_database",
"region_name": "my_region",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
),
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
# should raise an InvalidMappingException
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert profile_mapping.can_claim_connection()


def test_athena_profile_mapping_selected(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)
assert isinstance(profile_mapping, AthenaAccessKeyProfileMapping)


def test_athena_profile_args(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"database": mock_athena_conn.extra_dejson.get("database"),
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
"schema": mock_athena_conn.extra_dejson.get("schema"),
}


def test_athena_profile_args_overrides(
mock_athena_conn: Connection,
) -> None:
"""
Tests that you can override the profile values for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
profile_args={"schema": "my_custom_schema", "database": "my_custom_db"},
)
assert profile_mapping.profile_args == {
"schema": "my_custom_schema",
"database": "my_custom_db",
}

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"database": "my_custom_db",
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
"schema": "my_custom_schema",
}


def test_athena_profile_env_vars(
mock_athena_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly for Athena.
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password,
}
Loading