Add support to GCP connections that define keyfile_dict instead of keyfile #352

Merged · 9 commits · Jul 24, 2023
Changes from 8 commits
17 changes: 17 additions & 0 deletions .github/workflows/test.yml
@@ -99,8 +99,25 @@ jobs:
      matrix:
        python-version: ["3.10"]
        airflow-version: ["2.6"]
    if: >-
      github.event_name == 'push' ||
      (
        github.event_name == 'pull_request' &&
        github.event.pull_request.head.repo.fork == false
      ) ||
      (
        github.event_name == 'pull_request_target' &&
        contains(github.event.pull_request.labels.*.name, 'safe')
      )
    steps:
      - uses: actions/checkout@v3
        if: github.event_name != 'pull_request_target'

      - name: Checkout pull/${{ github.event.number }}
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
        if: github.event_name == 'pull_request_target'

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
1 change: 0 additions & 1 deletion cosmos/operators/local.py
@@ -231,7 +231,6 @@ def run_command(
            output_encoding=self.output_encoding,
            cwd=tmp_project_dir,
        )
        self.exception_handling(result)
        self.store_compiled_sql(tmp_project_dir, context)
        if self.callback:
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
@@ -8,6 +8,7 @@

from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
@@ -21,6 +22,7 @@

profile_mappings: list[Type[BaseProfileMapping]] = [
    GoogleCloudServiceAccountFileProfileMapping,
    GoogleCloudServiceAccountDictProfileMapping,
    DatabricksTokenProfileMapping,
    PostgresUserPasswordProfileMapping,
    RedshiftUserPasswordProfileMapping,
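For orientation, here is a minimal sketch of how this registry is likely consumed, based only on names visible in this PR (profile_mappings above, get_profile_mapping and can_claim_connection in the tests below). pick_profile_mapping is a hypothetical stand-in for illustration, not the library's actual implementation; ordering in profile_mappings matters because the first mapping to claim the connection wins.

    from typing import Any

    from airflow.hooks.base import BaseHook

    from cosmos.profiles import profile_mappings

    def pick_profile_mapping(conn_id: str, profile_args: dict[str, Any]):
        """Illustrative stand-in for get_profile_mapping (assumed behaviour)."""
        conn = BaseHook.get_connection(conn_id)
        # Walk the registered mappings in order; return the first one whose
        # connection type and required fields match this connection.
        for mapping_class in profile_mappings:
            mapping = mapping_class(conn, profile_args)
            if mapping.can_claim_connection():
                return mapping
        raise ValueError(f"No profile mapping found for connection {conn_id!r}")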
14 changes: 9 additions & 5 deletions cosmos/profiles/base.py
@@ -4,8 +4,8 @@
"""
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Any

@@ -42,18 +42,21 @@
        if self.conn.conn_type != self.airflow_connection_type:
            return False

        for field in self.required_fields:
            try:
                if not getattr(self, field):
                    logger.info(
                        "Not using mapping %s because %s is not set",
                        self.__class__.__name__,
                        field,
                    )
                    return False
            except AttributeError:
                logger.info(
                    "Not using mapping %s because %s is not set",
                    self.__class__.__name__,
                    field,
                )
@@ -77,7 +80,9 @@
        for field in self.secret_fields:
            env_var_name = self.get_env_var_name(field)
            value = self.get_dbt_value(field)
            if isinstance(value, dict):
                env_vars[env_var_name] = json.dumps(value)
            elif value is not None:
                env_vars[env_var_name] = str(value)

        return env_vars

(Codecov warning: the added json.dumps line, cosmos/profiles/base.py#L84, was not covered by tests.)
@@ -97,7 +102,6 @@
                "outputs": {target_name: profile_vars},
            }
        }
        return str(yaml.dump(profile_contents, indent=4))

    def get_dbt_value(self, name: str) -> Any:
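To isolate what the new isinstance branch in base.py adds, here is a self-contained sketch; render_env_value is a hypothetical helper written for illustration, not part of cosmos. The point is that dict-typed secrets such as a GCP keyfile_dict must be JSON-serialized, because str() on a dict yields single-quoted Python repr output that is not valid JSON.

    from __future__ import annotations

    import json
    from typing import Any

    def render_env_value(value: Any) -> str | None:
        # Dicts are serialized with json.dumps; str({"key": "value"}) would
        # produce "{'key': 'value'}", which dbt cannot parse back as JSON.
        if isinstance(value, dict):
            return json.dumps(value)
        elif value is not None:
            return str(value)
        return None

    assert render_env_value({"key": "value"}) == '{"key": "value"}'
    assert render_env_value(5432) == "5432"
    assert render_env_value(None) is None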
6 changes: 5 additions & 1 deletion cosmos/profiles/bigquery/__init__.py
@@ -1,5 +1,9 @@
"BigQuery Airflow connection -> dbt profile mappings"

from .service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

__all__ = [
    "GoogleCloudServiceAccountFileProfileMapping",
    "GoogleCloudServiceAccountDictProfileMapping",
]
49 changes: 49 additions & 0 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
@@ -0,0 +1,49 @@
"Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json."
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping


class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"""
Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json.

https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#service-account-file
https://airflow.apache.org/docs/apache-airflow-providers-google/stable/connections/gcp.html
"""

airflow_connection_type: str = "google_cloud_platform"

required_fields = [
"project",
"dataset",
"keyfile_dict",
]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
"dataset": ["extra.dataset", "dataset"],
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

@property
def profile(self) -> dict[str, Any | None]:
"""
Generates a GCP profile.
Even though the Airflow connection contains hard-coded Service account credentials,
we generate a temporary file and the DBT profile uses it.
"""
# keyfile_path = self.dump_credentials_to_disk()
return {

Check warning on line 41 in cosmos/profiles/bigquery/service_account_keyfile_dict.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/bigquery/service_account_keyfile_dict.py#L41

Added line #L41 was not covered by tests
"type": "bigquery",
"method": "service-account-json",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
tatiana marked this conversation as resolved.
Show resolved Hide resolved
**self.profile_args,
}
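A quick usage sketch under the same assumptions as the tests below, namely that the mapping can be constructed directly from an Airflow Connection object and that placeholder credentials are enough for claiming:

    import json

    from airflow.models.connection import Connection

    from cosmos.profiles.bigquery.service_account_keyfile_dict import (
        GoogleCloudServiceAccountDictProfileMapping,
    )

    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(
            {
                "project": "my_project",
                "dataset": "my_dataset",
                "keyfile_dict": {"key": "value"},  # placeholder, not real credentials
            }
        ),
    )

    mapping = GoogleCloudServiceAccountDictProfileMapping(conn, {})
    assert mapping.can_claim_connection()
    # The rendered dbt profile embeds the keyfile dict inline:
    assert mapping.profile["method"] == "service-account-json"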
8 changes: 8 additions & 0 deletions docs/dbt/connections-profiles.rst
@@ -89,6 +89,14 @@ Service Account File
    :members:


Service Account Dict
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: cosmos.profiles.bigquery.GoogleCloudServiceAccountDictProfileMapping
    :undoc-members:
    :members:


Databricks
----------
48 changes: 48 additions & 0 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
@@ -0,0 +1,48 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping


@pytest.fixture()
def mock_bigquery_conn_with_dict():  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "dataset": "my_dataset",
        "keyfile_dict": {"key": "value"},
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn


def test_bigquery_mapping_selected(mock_bigquery_conn_with_dict: Connection):
    profile_mapping = get_profile_mapping(
        mock_bigquery_conn_with_dict.conn_id,
        {"dataset": "my_dataset"},
    )
    assert isinstance(profile_mapping, GoogleCloudServiceAccountDictProfileMapping)


def test_connection_claiming_succeeds(mock_bigquery_conn_with_dict: Connection):
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    assert profile_mapping.can_claim_connection()


def test_connection_claiming_fails(mock_bigquery_conn_with_dict: Connection):
    # Remove the `dataset` key, which is mandatory
    mock_bigquery_conn_with_dict.extra = json.dumps({"project": "my_project", "keyfile_dict": {"key": "value"}})
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    assert not profile_mapping.can_claim_connection()