Skip to content

Commit

Permalink
Fix getting temporary AWS credentials with assume_role (astronomer#1081)
Browse files Browse the repository at this point in the history
When Airflow is getting temporary AWS credentials by assuming role with
`role_arn` as only `Connection` parameter, this cause task to fail due
to missing credentials. This is due to the latest changes related to
profile caching. The `env_vars` are accessed before `profile` which, in
this case, means required values are not populated yet.
  • Loading branch information
piotrkubicki authored and arojasb3 committed Jul 14, 2024
1 parent 25ffd29 commit 50b8fe1
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 7 deletions.
8 changes: 2 additions & 6 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,21 +287,17 @@ def ensure_profile(
if self.profiles_yml_filepath:
logger.info("Using user-supplied profiles.yml at %s", self.profiles_yml_filepath)
yield Path(self.profiles_yml_filepath), {}

elif self.profile_mapping:
if use_mock_values:
env_vars = {}
else:
env_vars = self.profile_mapping.env_vars

if is_profile_cache_enabled():
logger.info("Profile caching is enable.")
cached_profile_path = self._get_profile_path(use_mock_values)
env_vars = {} if use_mock_values else self.profile_mapping.env_vars
yield cached_profile_path, env_vars
else:
profile_contents = self.profile_mapping.get_profile_file_contents(
profile_name=self.profile_name, target_name=self.target_name, use_mock_values=use_mock_values
)
env_vars = {} if use_mock_values else self.profile_mapping.env_vars

if desired_profile_path:
logger.info(
Expand Down
31 changes: 31 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,37 @@ def test_load_via_dbt_ls_project_config_env_vars(
assert mock_popen.call_args.kwargs["env"]["MY_ENV_VAR"] == "my_value"


@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=False)
@patch("cosmos.config.is_profile_cache_enabled", return_value=False)
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
def test_profile_created_correctly_with_profile_mapping(
mock_validate,
mock_update_nodes,
mock_popen,
mock_enable_profile_cache,
mock_enable_cache,
tmp_dbt_project_dir,
postgres_profile_config,
):
"""Tests that the temporary profile is created without errors."""
mock_popen().communicate.return_value = ("", "")
mock_popen().returncode = 0
project_config = ProjectConfig(env_vars={})
render_config = RenderConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME)
profile_config = postgres_profile_config
execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME)
dbt_graph = DbtGraph(
project=project_config,
render_config=render_config,
execution_config=execution_config,
profile_config=profile_config,
)

assert dbt_graph.load_via_dbt_ls() == None


@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=False)
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
Expand Down
64 changes: 63 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, PropertyMock, call, patch

import pytest

from cosmos.config import CosmosConfigException, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import ExecutionMode, InvocationMode
from cosmos.exceptions import CosmosValueError
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping
from cosmos.profiles.postgres.user_pass import PostgresUserPasswordProfileMapping

DBT_PROJECTS_ROOT_DIR = Path(__file__).parent / "sample/"
Expand Down Expand Up @@ -142,6 +143,67 @@ def test_profile_config_validate_profiles_yml():
assert err_info.value.args[0] == "The file /tmp/no-exists does not exist."


@patch("cosmos.config.is_profile_cache_enabled", return_value=False)
@patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping.env_vars", new_callable=PropertyMock)
@patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping.get_profile_file_contents")
@patch("cosmos.config.Path")
def test_profile_config_ensure_profile_without_caching_calls_get_profile_file_content_before_env_vars(
mock_path, mock_get_profile_file_contents, mock_env_vars, mock_cache
):
"""
The `env_vars` should not be called if profile file is not populated.
"""
profile_mapping = AthenaAccessKeyProfileMapping(conn_id="test", profile_args={})
profile_config = ProfileConfig(profile_name="test", target_name="test", profile_mapping=profile_mapping)
mock_manager = Mock()
mock_manager.attach_mock(mock_get_profile_file_contents, "get_profile_file_contents")
mock_manager.attach_mock(mock_env_vars, "env_vars")

with profile_config.ensure_profile(desired_profile_path=mock_path):
mock_get_profile_file_contents.assert_called_once()
mock_env_vars.assert_called_once()
expected_calls = [
call.get_profile_file_contents(profile_name="test", target_name="test", use_mock_values=False),
call.env_vars,
]
mock_manager.assert_has_calls(expected_calls, any_order=False)


@patch("cosmos.config.create_cache_profile")
@patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping.version")
@patch("cosmos.config.get_cached_profile", return_value=None)
@patch("cosmos.config.is_profile_cache_enabled", return_value=True)
@patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping.env_vars", new_callable=PropertyMock)
@patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping.get_profile_file_contents")
@patch("cosmos.config.Path")
def test_profile_config_ensure_profile_with_caching_calls_get_profile_file_content_before_env_vars(
mock_path,
mock_get_profile_file_contents,
mock_env_vars,
mock_cache,
mock_get_cached_profile,
mock_version,
mock_create_cache_profile,
):
"""
The `env_vars` should not be called if profile file is not populated.
"""
profile_mapping = AthenaAccessKeyProfileMapping(conn_id="test", profile_args={})
profile_config = ProfileConfig(profile_name="test", target_name="test", profile_mapping=profile_mapping)
mock_manager = Mock()
mock_manager.attach_mock(mock_get_profile_file_contents, "get_profile_file_contents")
mock_manager.attach_mock(mock_env_vars, "env_vars")

with profile_config.ensure_profile(desired_profile_path=mock_path):
mock_get_profile_file_contents.assert_called_once()
mock_env_vars.assert_called_once()
expected_calls = [
call.get_profile_file_contents(profile_name="test", target_name="test", use_mock_values=False),
call.env_vars,
]
mock_manager.assert_has_calls(expected_calls, any_order=False)


@patch("cosmos.config.shutil.which", return_value=None)
def test_render_config_without_dbt_cmd(mock_which):
render_config = RenderConfig()
Expand Down

0 comments on commit 50b8fe1

Please sign in to comment.