Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disabled_algorithms as an extra parameter for SSH connections #24090

Merged
merged 4 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class SSHHook(BaseHook):
:param keepalive_interval: send a keepalive packet to remote host every
keepalive_interval seconds
:param banner_timeout: timeout to wait for banner from the server in seconds
:param disabled_algorithms: dictionary mapping algorithm type to an
iterable of algorithm identifiers, which will be disabled for the
lifetime of the transport
"""

# List of classes to try loading private keys as, ordered (roughly) by most common to least common
Expand Down Expand Up @@ -112,6 +115,7 @@ def __init__(
conn_timeout: Optional[int] = None,
keepalive_interval: int = 30,
banner_timeout: float = 30.0,
disabled_algorithms: Optional[dict] = None,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
Expand All @@ -125,6 +129,7 @@ def __init__(
self.conn_timeout = conn_timeout
self.keepalive_interval = keepalive_interval
self.banner_timeout = banner_timeout
self.disabled_algorithms = disabled_algorithms
self.host_proxy_cmd = None

# Default values, overridable from Connection
Expand Down Expand Up @@ -197,6 +202,9 @@ def __init__(
):
self.look_for_keys = False

if "disabled_algorithms" in extra_options:
self.disabled_algorithms = extra_options.get("disabled_algorithms")

if host_key is not None:
if host_key.startswith("ssh-"):
key_type, host_key = host_key.split(None)[:2]
Expand Down Expand Up @@ -313,6 +321,9 @@ def get_conn(self) -> paramiko.SSHClient:
if self.key_file:
connect_kwargs.update(key_filename=self.key_file)

if self.disabled_algorithms:
connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)

log_before_sleep = lambda retry_state: self.log.info(
"Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
)
Expand Down
2 changes: 2 additions & 0 deletions docs/apache-airflow-providers-ssh/connections/ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Extra (optional)
* ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts that has host key changed or when you get 'REMOTE HOST IDENTIFICATION HAS CHANGED' error. This won't protect against Man-In-The-Middle attacks. Other possible solution is to remove the host entry from ``~/.ssh/known_hosts`` file. Default is ``false``.
* ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/``
* ``host_key`` - The base64 encoded ssh-rsa public key of the host or "ssh-<key type> <key data>" (as you would find in the ``known_hosts`` file). Specifying this allows making the connection if and only if the public key of the endpoint matches this value.
* ``disabled_algorithms`` - A dictionary mapping algorithm type to an iterable of algorithm identifiers, which will be disabled for the lifetime of the transport.

Example "extras" field:

Expand All @@ -66,6 +67,7 @@ Extra (optional)
"look_for_keys": "false",
"allow_host_key_change": "false",
"host_key": "AAAHD...YDWwq=="
"disabled_algorithms": {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]}
}

When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
Expand Down
34 changes: 34 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def generate_host_key(pkey: paramiko.PKey):
PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10))
TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE)

TEST_DISABLED_ALGORITHMS = {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]}


class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_NO_EXTRA = 'ssh_with_no_extra'
Expand All @@ -96,6 +98,7 @@ class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE = (
'ssh_with_host_key_and_allow_host_key_changes_true'
)
CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS = 'ssh_with_extra_disabled_algorithms'

@classmethod
def tearDownClass(cls) -> None:
Expand All @@ -115,6 +118,7 @@ def tearDownClass(cls) -> None:
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS,
]
connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
Expand Down Expand Up @@ -263,6 +267,14 @@ def setUpClass(cls) -> None:
),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS,
host='localhost',
conn_type='ssh',
extra=json.dumps({"disabled_algorithms": TEST_DISABLED_ALGORITHMS}),
)
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
Expand Down Expand Up @@ -747,6 +759,28 @@ def test_ssh_connection_with_all_timeout_param_and_extra_combinations(
look_for_keys=True,
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_with_extra_disabled_algorithms(self, ssh_mock):
hook = SSHHook(
ssh_conn_id=self.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS,
remote_host='remote_host',
port='port',
username='username',
)

with hook.get_conn():
ssh_mock.return_value.connect.assert_called_once_with(
banner_timeout=30.0,
hostname='remote_host',
username='username',
compress=True,
timeout=10,
port='port',
sock=None,
look_for_keys=True,
disabled_algorithms=TEST_DISABLED_ALGORITHMS,
)

def test_openssh_private_key(self):
# Paramiko behaves differently with OpenSSH generated keys to paramiko
# generated keys, so we need a test one.
Expand Down