diff --git a/airflow/providers/amazon/aws/secrets/ssm.py b/airflow/providers/amazon/aws/secrets/ssm.py index 61bf592ed84ce3..0d62e8074f8846 100644 --- a/airflow/providers/amazon/aws/secrets/ssm.py +++ b/airflow/providers/amazon/aws/secrets/ssm.py @@ -24,9 +24,10 @@ from airflow.models import Connection from airflow.secrets import CONN_ENV_PREFIX, BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin -class AwsSsmSecretsBackend(BaseSecretsBackend): +class AwsSsmSecretsBackend(BaseSecretsBackend, LoggingMixin): """ Retrieves Connection object from AWS SSM Parameter Store @@ -66,26 +67,38 @@ def build_ssm_path(self, conn_id: str): param_path = self.prefix + "/" + param_name return param_path - def get_conn_uri(self, conn_id: str): + def get_conn_uri(self, conn_id: str) -> Optional[str]: """ Get param value :param conn_id: connection id + :type conn_id: str """ session = boto3.Session(profile_name=self.profile_name) client = session.client("ssm") - response = client.get_parameter( - Name=self.build_ssm_path(conn_id=conn_id), WithDecryption=True - ) - value = response["Parameter"]["Value"] - return value + ssm_path = self.build_ssm_path(conn_id=conn_id) + try: + response = client.get_parameter( + Name=ssm_path, WithDecryption=False + ) + value = response["Parameter"]["Value"] + return value + except client.exceptions.ParameterNotFound: + self.log.info( + "An error occurred (ParameterNotFound) when calling the GetParameter operation: " + "Parameter %s not found.", ssm_path + ) + return None def get_connections(self, conn_id: str) -> List[Connection]: """ Create connection object. :param conn_id: connection id + :type conn_id: str """ conn_uri = self.get_conn_uri(conn_id=conn_id) + if not conn_uri: + return [] conn = Connection(conn_id=conn_id, uri=conn_uri) return [conn] diff --git a/airflow/secrets/__init__.py b/airflow/secrets/__init__.py index 69aa79cb083d0a..10368ffc101fb7 100644 --- a/airflow/secrets/__init__.py +++ b/airflow/secrets/__init__.py @@ -66,7 +66,7 @@ def get_connections(conn_id: str) -> List[Connection]: :param conn_id: connection id :return: array of connections """ - for secrets_backend in secrets_backend_list: + for secrets_backend in ensure_secrets_loaded(): conn_list = secrets_backend.get_connections(conn_id=conn_id) if conn_list: return list(conn_list) @@ -100,4 +100,15 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]: return backend_list +def ensure_secrets_loaded() -> List[BaseSecretsBackend]: + """ + Ensure that all secrets backends are loaded. + If the secrets_backend_list contains only 2 default backends, reload it. + """ + # Check if the secrets_backend_list contains only 2 default backends + if len(secrets_backend_list) == 2: + return initialize_secrets_backends() + return secrets_backend_list + + secrets_backend_list = initialize_secrets_backends() diff --git a/tests/providers/amazon/aws/secrets/test_ssm.py b/tests/providers/amazon/aws/secrets/test_ssm.py index 1a3927a8b45647..2874260e4ba33e 100644 --- a/tests/providers/amazon/aws/secrets/test_ssm.py +++ b/tests/providers/amazon/aws/secrets/test_ssm.py @@ -15,14 +15,39 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock +from unittest import TestCase, mock + +from moto import mock_ssm from airflow.providers.amazon.aws.secrets.ssm import AwsSsmSecretsBackend +from airflow.secrets import get_connections + + +class TestSsmSecrets(TestCase): + @mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri") + def test_aws_ssm_get_connections(self, mock_get_uri): + mock_get_uri.return_value = "scheme://user:pass@host:100" + conn_list = AwsSsmSecretsBackend().get_connections("fake_conn") + conn = conn_list[0] + assert conn.host == 'host' + + @mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow', + }) + @mock_ssm + def test_get_conn_uri_non_existent_key(self): + """ + Test that if the key with connection ID is not present in SSM, + AwsSsmSecretsBackend.get_connections should return None and fallback to the + environment variable if it is set + """ + conn_id = "test_mysql" + test_client = AwsSsmSecretsBackend() + self.assertIsNone(test_client.get_conn_uri(conn_id=conn_id)) + self.assertEqual([], test_client.get_connections(conn_id=conn_id)) -@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri") -def test_aws_ssm_get_connections(mock_get_uri): - mock_get_uri.side_effect = ["scheme://user:pass@host:100"] - conn_list = AwsSsmSecretsBackend().get_connections("fake_conn") - conn = conn_list[0] - assert conn.host == 'host' + # Fallback to connection defined in Environment Variable + self.assertEqual( + "mysql://airflow:airflow@host:5432/airflow", + get_connections(conn_id="test_mysql")[0].get_uri()) diff --git a/tests/secrets/test_secrets.py b/tests/secrets/test_secrets.py index 41fac05b89a05e..1f43b0d5eeff28 100644 --- a/tests/secrets/test_secrets.py +++ b/tests/secrets/test_secrets.py @@ -19,7 +19,7 @@ import unittest from unittest import mock -from airflow.secrets import get_connections, initialize_secrets_backends +from airflow.secrets import ensure_secrets_loaded, get_connections, initialize_secrets_backends from tests.test_utils.config import conf_vars @@ -51,6 +51,28 @@ def test_initialize_secrets_backends(self): self.assertEqual(3, len(backends)) self.assertIn('AwsSsmSecretsBackend', backend_classes) + @conf_vars({ + ("secrets", "backend"): "airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend", + ("secrets", "backend_kwargs"): '{"prefix": "/airflow", "profile_name": null}', + }) + @mock.patch.dict('os.environ', { + 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow', + }) + @mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri") + def test_backend_fallback_to_env_var(self, mock_get_uri): + mock_get_uri.return_value = None + + backends = ensure_secrets_loaded() + backend_classes = [backend.__class__.__name__ for backend in backends] + self.assertIn('AwsSsmSecretsBackend', backend_classes) + + uri = get_connections(conn_id="test_mysql") + + # Assert that AwsSsmSecretsBackend.get_conn_uri was called + mock_get_uri.assert_called_once_with(conn_id='test_mysql') + + self.assertEqual('mysql://airflow:airflow@host:5432/airflow', uri[0].get_uri()) + if __name__ == "__main__": unittest.main()