Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-5705] Fix bugs in AWS SSM Secrets Backend #7745

Merged
merged 2 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions airflow/providers/amazon/aws/secrets/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

from airflow.models import Connection
from airflow.secrets import CONN_ENV_PREFIX, BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin


class AwsSsmSecretsBackend(BaseSecretsBackend):
class AwsSsmSecretsBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from AWS SSM Parameter Store

Expand Down Expand Up @@ -66,26 +67,38 @@ def build_ssm_path(self, conn_id: str):
param_path = self.prefix + "/" + param_name
return param_path

def get_conn_uri(self, conn_id: str):
def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get param value

:param conn_id: connection id
:type conn_id: str
"""
session = boto3.Session(profile_name=self.profile_name)
client = session.client("ssm")
response = client.get_parameter(
Name=self.build_ssm_path(conn_id=conn_id), WithDecryption=True
)
value = response["Parameter"]["Value"]
return value
ssm_path = self.build_ssm_path(conn_id=conn_id)
try:
response = client.get_parameter(
Name=ssm_path, WithDecryption=False
)
value = response["Parameter"]["Value"]
return value
except client.exceptions.ParameterNotFound:
self.log.info(
"An error occurred (ParameterNotFound) when calling the GetParameter operation: "
"Parameter %s not found.", ssm_path
)
return None

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Create connection object.

:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
13 changes: 12 additions & 1 deletion airflow/secrets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_connections(conn_id: str) -> List[Connection]:
:param conn_id: connection id
:return: array of connections
"""
for secrets_backend in secrets_backend_list:
for secrets_backend in ensure_secrets_loaded():
conn_list = secrets_backend.get_connections(conn_id=conn_id)
if conn_list:
return list(conn_list)
Expand Down Expand Up @@ -100,4 +100,15 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
return backend_list


def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
"""
Ensure that all secrets backends are loaded.
If the secrets_backend_list contains only 2 default backends, reload it.
"""
# Check if the secrets_backend_list contains only 2 default backends
if len(secrets_backend_list) == 2:
return initialize_secrets_backends()
return secrets_backend_list


secrets_backend_list = initialize_secrets_backends()
39 changes: 32 additions & 7 deletions tests/providers/amazon/aws/secrets/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,39 @@
# specific language governing permissions and limitations
# under the License.

from unittest import mock
from unittest import TestCase, mock

from moto import mock_ssm

from airflow.providers.amazon.aws.secrets.ssm import AwsSsmSecretsBackend
from airflow.secrets import get_connections


class TestSsmSecrets(TestCase):
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_aws_ssm_get_connections(self, mock_get_uri):
mock_get_uri.return_value = "scheme://user:pass@host:100"
conn_list = AwsSsmSecretsBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'

@mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
})
@mock_ssm
def test_get_conn_uri_non_existent_key(self):
"""
Test that if the key with connection ID is not present in SSM,
AwsSsmSecretsBackend.get_connections should return None and fallback to the
environment variable if it is set
"""
conn_id = "test_mysql"
test_client = AwsSsmSecretsBackend()

self.assertIsNone(test_client.get_conn_uri(conn_id=conn_id))
self.assertEqual([], test_client.get_connections(conn_id=conn_id))

@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_aws_ssm_get_connections(mock_get_uri):
mock_get_uri.side_effect = ["scheme://user:pass@host:100"]
conn_list = AwsSsmSecretsBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'
# Fallback to connection defined in Environment Variable
self.assertEqual(
"mysql://airflow:airflow@host:5432/airflow",
get_connections(conn_id="test_mysql")[0].get_uri())
24 changes: 23 additions & 1 deletion tests/secrets/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest
from unittest import mock

from airflow.secrets import get_connections, initialize_secrets_backends
from airflow.secrets import ensure_secrets_loaded, get_connections, initialize_secrets_backends
from tests.test_utils.config import conf_vars


Expand Down Expand Up @@ -51,6 +51,28 @@ def test_initialize_secrets_backends(self):
self.assertEqual(3, len(backends))
self.assertIn('AwsSsmSecretsBackend', backend_classes)

@conf_vars({
("secrets", "backend"): "airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend",
("secrets", "backend_kwargs"): '{"prefix": "/airflow", "profile_name": null}',
})
@mock.patch.dict('os.environ', {
'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',
})
@mock.patch("airflow.providers.amazon.aws.secrets.ssm.AwsSsmSecretsBackend.get_conn_uri")
def test_backend_fallback_to_env_var(self, mock_get_uri):
mock_get_uri.return_value = None

backends = ensure_secrets_loaded()
backend_classes = [backend.__class__.__name__ for backend in backends]
self.assertIn('AwsSsmSecretsBackend', backend_classes)

uri = get_connections(conn_id="test_mysql")

# Assert that AwsSsmSecretsBackend.get_conn_uri was called
mock_get_uri.assert_called_once_with(conn_id='test_mysql')

self.assertEqual('mysql://airflow:airflow@host:5432/airflow', uri[0].get_uri())


if __name__ == "__main__":
unittest.main()