Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BigQuery keyfile_dict profile mapping use env vars for sensitive fields #471

Merged
merged 4 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Any
import json
from cosmos.exceptions import CosmosValueError

from cosmos.profiles.base import BaseProfileMapping

Expand All @@ -23,6 +24,8 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"keyfile_json",
]

secret_fields = ["private_key_id", "private_key"]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
Expand All @@ -31,6 +34,8 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"keyfile_json": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

_env_vars: dict[str, str] = {}

@property
def profile(self) -> dict[str, Any | None]:
"""
Expand All @@ -48,12 +53,27 @@ def profile(self) -> dict[str, Any | None]:

def transform_keyfile_json(self, keyfile_json: str | dict[str, str]) -> dict[str, str]:
    """
    Transforms the keyfile_json param to a dict if it is a string, and sets environment
    variables for the service account json secret fields.

    Each field named in ``self.secret_fields`` is replaced in the returned dict by the
    value of ``self.get_env_var_format(field)``, and its original value is recorded in
    ``self._env_vars`` under ``self.get_env_var_name(field)``.

    :param keyfile_json: The service account key, either as a dict or as a JSON string.
    :returns: A new dict with the secret fields replaced; the argument is left untouched.
    :raises CosmosValueError: If the JSON string does not decode to a dict, or if any
        field in ``self.secret_fields`` is missing or ``None``.
    """
    if isinstance(keyfile_json, dict):
        # Shallow copy: the caller's dict must keep its real secret values.
        keyfile_json_dict = dict(keyfile_json)
    else:
        keyfile_json_dict = json.loads(keyfile_json)
        if not isinstance(keyfile_json_dict, dict):
            raise CosmosValueError("keyfile_json cannot be loaded as a dict.")

    # Collect first and publish at the end, so a missing field leaves no partial state.
    secret_env_vars: dict[str, str] = {}
    for field in self.secret_fields:
        value = keyfile_json_dict.get(field)
        if value is None:
            raise CosmosValueError(f"Could not find a value in service account json field: {field}.")
        secret_env_vars[self.get_env_var_name(field)] = value
        keyfile_json_dict[field] = self.get_env_var_format(field)

    # Rebind instead of mutating in place: `_env_vars` starts out as a class-level dict,
    # and writing into it would leak one connection's secrets to every other instance.
    self._env_vars = {**self._env_vars, **secret_env_vars}

    return keyfile_json_dict

@property
def env_vars(self) -> dict[str, str]:
    """
    Environment variables that should be set based on ``self.secret_fields``.

    :returns: The mapping held in ``self._env_vars``.
    """
    secret_env_vars = self._env_vars
    return secret_env_vars
57 changes: 54 additions & 3 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@

import pytest
from airflow.models.connection import Connection
from cosmos.exceptions import CosmosValueError

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

# Minimal service-account keyfile shared by the tests in this module. It carries
# both "private_key_id" and "private_key" so that tests can drop one at a time.
sample_keyfile_dict = {
    "type": "service_account",
    "private_key_id": "my_private_key_id",
    "private_key": "my_private_key",
}

@pytest.fixture(params=[{"key": "value"}, '{"key": "value"}'])

@pytest.fixture(params=[sample_keyfile_dict, json.dumps(sample_keyfile_dict)])
def mock_bigquery_conn_with_dict(request): # type: ignore
"""
Mocks and returns an Airflow BigQuery connection.
Expand Down Expand Up @@ -43,7 +50,7 @@ def test_connection_claiming_succeeds(mock_bigquery_conn_with_dict: Connection):

def test_connection_claiming_fails(mock_bigquery_conn_with_dict: Connection):
    """
    Tests that the mapping does not claim a connection whose extras lack `dataset`.
    """
    # Remove the `dataset` key, which is mandatory
    mock_bigquery_conn_with_dict.extra = json.dumps({"project": "my_project", "keyfile_dict": sample_keyfile_dict})
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    assert not profile_mapping.can_claim_connection()

Expand All @@ -56,6 +63,50 @@ def test_profile(mock_bigquery_conn_with_dict: Connection):
"project": "my_project",
"dataset": "my_dataset",
"threads": 1,
"keyfile_json": {"key": "value"},
"keyfile_json": {
"type": "service_account",
"private_key_id": "{{ env_var('COSMOS_CONN_GOOGLE_CLOUD_PLATFORM_PRIVATE_KEY_ID') }}",
"private_key": "{{ env_var('COSMOS_CONN_GOOGLE_CLOUD_PLATFORM_PRIVATE_KEY') }}",
},
}
assert profile_mapping.profile == expected


def test_profile_env_vars(mock_bigquery_conn_with_dict: Connection):
    """
    Checks that the mapping's ``env_vars`` hold the original secret keyfile values.
    """
    expected_env_vars = {
        "COSMOS_CONN_GOOGLE_CLOUD_PLATFORM_PRIVATE_KEY_ID": "my_private_key_id",
        "COSMOS_CONN_GOOGLE_CLOUD_PLATFORM_PRIVATE_KEY": "my_private_key",
    }
    conn_id = mock_bigquery_conn_with_dict.conn_id
    profile_mapping = get_automatic_profile_mapping(conn_id, {"dataset": "my_dataset"})
    assert profile_mapping.env_vars == expected_env_vars


def test_transform_keyfile_json_missing_dict():
    """
    Tests that a cosmos error is raised if the keyfile_json cannot be loaded as a dict.
    """
    import re  # local import: only needed to escape the expected message

    # Valid JSON, but it decodes to a list rather than a dict.
    keyfile_json = '["value"]'
    expected_cosmos_error = "keyfile_json cannot be loaded as a dict."

    profile_mapping = GoogleCloudServiceAccountDictProfileMapping("", {})
    # `match` is applied as a regex search, so escape the message to keep "." literal.
    with pytest.raises(CosmosValueError, match=re.escape(expected_cosmos_error)):
        profile_mapping.transform_keyfile_json(keyfile_json)


@pytest.mark.parametrize("missing_secret_key", ["private_key_id", "private_key"])
def test_transform_keyfile_json_missing_secret_key(missing_secret_key: str):
    """
    Tests that a cosmos error is raised if the keyfile_json is missing a secret key.
    """
    import re  # local import: only needed to escape the expected message

    # Copy of the sample keyfile with the parametrized secret field left out.
    keyfile_json = {k: v for k, v in sample_keyfile_dict.items() if k != missing_secret_key}
    expected_cosmos_error = f"Could not find a value in service account json field: {missing_secret_key}."

    profile_mapping = GoogleCloudServiceAccountDictProfileMapping("", {})

    # `match` is applied as a regex search, so escape the message to keep "." literal.
    with pytest.raises(CosmosValueError, match=re.escape(expected_cosmos_error)):
        profile_mapping.transform_keyfile_json(keyfile_json)
Loading