Skip to content

Commit

Permalink
[AIRFLOW-5705] Fix bugs in AWS SSM Secrets Backend (apache#7745)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Mar 17, 2020
1 parent 968a3f9 commit 2a54512
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 16 deletions.
27 changes: 20 additions & 7 deletions airflow/providers/amazon/aws/secrets/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

from airflow.models import Connection
from airflow.secrets import CONN_ENV_PREFIX, BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin


class AwsSsmSecretsBackend(BaseSecretsBackend):
class AwsSsmSecretsBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from AWS SSM Parameter Store
Expand Down Expand Up @@ -66,26 +67,38 @@ def build_ssm_path(self, conn_id: str):
param_path = self.prefix + "/" + param_name
return param_path

def get_conn_uri(self, conn_id: str):
def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get param value
:param conn_id: connection id
:type conn_id: str
"""
session = boto3.Session(profile_name=self.profile_name)
client = session.client("ssm")
response = client.get_parameter(
Name=self.build_ssm_path(conn_id=conn_id), WithDecryption=True
)
value = response["Parameter"]["Value"]
return value
ssm_path = self.build_ssm_path(conn_id=conn_id)
try:
response = client.get_parameter(
Name=ssm_path, WithDecryption=False
)
value = response["Parameter"]["Value"]
return value
except client.exceptions.ParameterNotFound:
self.log.info(
"An error occurred (ParameterNotFound) when calling the GetParameter operation: "
"Parameter %s not found.", ssm_path
)
return None

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Create connection object.
:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
13 changes: 12 additions & 1 deletion airflow/secrets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_connections(conn_id: str) -> List[Connection]:
:param conn_id: connection id
:return: array of connections
"""
for secrets_backend in secrets_backend_list:
for secrets_backend in ensure_secrets_loaded():
conn_list = secrets_backend.get_connections(conn_id=conn_id)
if conn_list:
return list(conn_list)
Expand Down Expand Up @@ -100,4 +100,15 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
return backend_list


def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
"""
Ensure that all secrets backends are loaded.
If the secrets_backend_list contains only 2 default backends, reload it.
"""
# Check if the secrets_backend_list contains only 2 default backends
if len(secrets_backend_list) == 2:
return initialize_secrets_backends()
return secrets_backend_list


secrets_backend_list = initialize_secrets_backends()
39 changes: 32 additions & 7 deletions tests/providers/amazon/aws/secrets/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,39 @@
# specific language governing permissions and limitations
# under the License.

from unittest import mock
from unittest import TestCase, mock

from moto import mock_ssm

from airflow.providers.amazon.aws.secrets.ssm import AwsSsmSecretsBackend
from airflow.secrets import get_connections


class TestSsmSecrets(TestCase):
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_aws_ssm_get_connections(self, mock_get_uri):
mock_get_uri.return_value = "scheme://user:pass@host:100"
conn_list = AwsSsmSecretsBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'

@mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
})
@mock_ssm
def test_get_conn_uri_non_existent_key(self):
"""
Test that if the key with connection ID is not present in SSM,
AwsSsmSecretsBackend.get_connections should return None and fallback to the
environment variable if it is set
"""
conn_id = "test_mysql"
test_client = AwsSsmSecretsBackend()

self.assertIsNone(test_client.get_conn_uri(conn_id=conn_id))
self.assertEqual([], test_client.get_connections(conn_id=conn_id))

@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_aws_ssm_get_connections(mock_get_uri):
mock_get_uri.side_effect = ["scheme://user:pass@host:100"]
conn_list = AwsSsmSecretsBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'
# Fallback to connection defined in Environment Variable
self.assertEqual(
"mysql://airflow:airflow@host:5432/airflow",
get_connections(conn_id="test_mysql")[0].get_uri())
24 changes: 23 additions & 1 deletion tests/secrets/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest
from unittest import mock

from airflow.secrets import get_connections, initialize_secrets_backends
from airflow.secrets import ensure_secrets_loaded, get_connections, initialize_secrets_backends
from tests.test_utils.config import conf_vars


Expand Down Expand Up @@ -51,6 +51,28 @@ def test_initialize_secrets_backends(self):
self.assertEqual(3, len(backends))
self.assertIn('AwsSsmSecretsBackend', backend_classes)

@conf_vars({
("secrets", "backend"): "airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend",
("secrets", "backend_kwargs"): '{"prefix": "/airflow", "profile_name": null}',
})
@mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
})
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_backend_fallback_to_env_var(self, mock_get_uri):
mock_get_uri.return_value = None

backends = ensure_secrets_loaded()
backend_classes = [backend.__class__.__name__ for backend in backends]
self.assertIn('AwsSsmSecretsBackend', backend_classes)

uri = get_connections(conn_id="test_mysql")

# Assert that AwsSsmSecretsBackend.get_conn_uri was called
mock_get_uri.assert_called_once_with(conn_id='test_mysql')

self.assertEqual('mysql://airflow:airflow@host:5432/airflow', uri[0].get_uri())


if __name__ == "__main__":
unittest.main()

0 comments on commit 2a54512

Please sign in to comment.