diff --git a/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py new file mode 100644 index 00000000000000..d3d9432fe2f46e --- /dev/null +++ b/airflow/migrations/versions/449b4072c2da_increase_size_of_connection_extra_field_.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Increase size of connection.extra field to handle multiple RSA keys + +Revision ID: 449b4072c2da +Revises: e959f08ac86c +Create Date: 2020-03-16 19:02:55.337710 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '449b4072c2da' +down_revision = 'e959f08ac86c' +branch_labels = None +depends_on = None + + +def upgrade(): + """Apply increase_length_for_connection_password""" + with op.batch_alter_table('connection', schema=None) as batch_op: + batch_op.alter_column( + 'extra', + existing_type=sa.VARCHAR(length=5000), + type_=sa.TEXT(), + existing_nullable=True, + ) + + +def downgrade(): + """Unapply increase_length_for_connection_password""" + with op.batch_alter_table('connection', schema=None) as batch_op: + batch_op.alter_column( + 'extra', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=5000), + existing_nullable=True, + ) diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 1159a44a906489..c030571fbd4ef3 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -102,7 +102,7 @@ class Connection(Base, LoggingMixin): # pylint: disable=too-many-instance-attri port = Column(Integer()) is_encrypted = Column(Boolean, unique=False, default=False) is_extra_encrypted = Column(Boolean, unique=False, default=False) - _extra = Column('extra', String(5000)) + _extra = Column('extra', Text()) def __init__( # pylint: disable=too-many-arguments self, diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index 498f362c10179b..e2a991e8704f2a 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -115,6 +115,12 @@ def get_conn(self) -> pysftp.Connection: cnopts = pysftp.CnOpts() if self.no_host_key_check: cnopts.hostkeys = None + else: + if self.host_key is not None: + cnopts.hostkeys.add(self.remote_host, 'ssh-rsa', self.host_key) + else: + pass # will fallback to system host keys if none explicitly specified in conn extra + cnopts.compression = self.compress cnopts.ciphers = self.ciphers conn_params = { diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index d420b1bf271d7d..1b35db31fd54a2 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -19,6 +19,7 @@ import getpass import os import warnings +from base64 import decodebytes from io import StringIO from typing import Dict, Optional, Tuple, Union @@ -30,7 +31,7 @@ from airflow.hooks.base import BaseHook -class SSHHook(BaseHook): +class SSHHook(BaseHook): # pylint: disable=too-many-instance-attributes """ Hook for ssh remote execution using Paramiko. ref: https://github.com/paramiko/paramiko @@ -72,7 +73,7 @@ def get_ui_field_behaviour() -> Dict: }, } - def __init__( + def __init__( # pylint: disable=too-many-statements self, ssh_conn_id: Optional[str] = None, remote_host: Optional[str] = None, @@ -99,6 +100,7 @@ def __init__( self.no_host_key_check = True self.allow_host_key_change = False self.host_proxy = None + self.host_key = None self.look_for_keys = True # Placeholder for deprecated __enter__ @@ -149,7 +151,9 @@ def __init__( and str(extra_options["look_for_keys"]).lower() == 'false' ): self.look_for_keys = False - + if "host_key" in extra_options and self.no_host_key_check is False: + decoded_host_key = decodebytes(extra_options["host_key"].encode('utf-8')) + self.host_key = paramiko.RSAKey(data=decoded_host_key) if self.pkey and self.key_file: raise AirflowException( "Params key_file and private_key both provided. Must provide no more than one." @@ -198,10 +202,18 @@ def get_conn(self) -> paramiko.SSHClient: 'This wont protect against Man-In-The-Middle attacks' ) client.load_system_host_keys() + if self.no_host_key_check: self.log.warning('No Host Key Verification. This wont protect against Man-In-The-Middle attacks') # Default is RejectPolicy client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + else: + if self.host_key is not None: + client_host_keys = client.get_host_keys() + client_host_keys.add(self.remote_host, 'ssh-rsa', self.host_key) + else: + pass # will fallback to system host keys if none explicitly specified in conn extra + connect_kwargs = dict( hostname=self.remote_host, username=self.username, diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst b/docs/apache-airflow-providers-ssh/connections/ssh.rst index 54e902e36a5828..f320381904fd0e 100644 --- a/docs/apache-airflow-providers-ssh/connections/ssh.rst +++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst @@ -47,9 +47,10 @@ Extra (optional) * ``private_key_passphrase`` - Content of the private key passphrase used to decrypt the private key. * ``timeout`` - An optional timeout (in seconds) for the TCP connect. Default is ``10``. * ``compress`` - ``true`` to ask the remote client/server to compress traffic; ``false`` to refuse compression. Default is ``true``. - * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files. + * ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with either no entries in ``~/.ssh/known_hosts`` (Hosts file) or not present in the ``host_key`` extra. This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files. * ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts that has host key changed or when you get 'REMOTE HOST IDENTIFICATION HAS CHANGED' error. This wont protect against Man-In-The-Middle attacks. Other possible solution is to remove the host entry from ``~/.ssh/known_hosts`` file. Default is ``false``. * ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/`` + * ``host_key`` - The base64 encoded ssh-rsa public key of the host, as you would find in the ``known_hosts`` file. Specifying this, along with ``no_host_key_check=False`` allows you to only make the connection if the public key of the endpoint matches this value. Example "extras" field: @@ -59,9 +60,10 @@ Extra (optional) "key_file": "/home/airflow/.ssh/id_rsa", "timeout": "10", "compress": "false", + "look_for_keys": "false", "no_host_key_check": "false", "allow_host_key_change": "false", - "look_for_keys": "false" + "host_key": "AAAHD...YDWwq==" } When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 238021e18a2664..84f1860a913d67 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1157,6 +1157,7 @@ rootcss rowid rpc rshift +rsa rst rtype ru diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py index 45097e6c476657..9211c30abc254a 100644 --- a/tests/providers/sftp/hooks/test_sftp.py +++ b/tests/providers/sftp/hooks/test_sftp.py @@ -15,12 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import json import os import shutil import unittest +from io import StringIO from unittest import mock +import paramiko import pysftp from parameterized import parameterized @@ -28,6 +30,15 @@ from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.utils.session import provide_session + +def generate_host_key(pkey: paramiko.PKey): + key_fh = StringIO() + pkey.write_private_key(key_fh) + key_fh.seek(0) + key_obj = paramiko.RSAKey(file_obj=key_fh) + return key_obj.get_base64() + + TMP_PATH = '/tmp' TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' SUB_DIR = "sub_dir" @@ -35,6 +46,9 @@ SFTP_CONNECTION_USER = "root" +TEST_PKEY = paramiko.RSAKey.generate(4096) +TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY) + class TestSFTPHook(unittest.TestCase): @provide_session @@ -178,6 +192,31 @@ def test_no_host_key_check_no_ignore(self, get_connection): hook = SFTPHook() self.assertEqual(hook.no_host_key_check, False) + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key_default(self, get_connection): + connection = Connection(login='login', host='host') + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.host_key, None) + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key(self, get_connection): + connection = Connection( + login='login', + host='host', + extra=json.dumps({"host_key": TEST_HOST_KEY, "no_host_key_check": False}), + ) + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.host_key.get_base64(), TEST_HOST_KEY) + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key_with_no_host_key_check(self, get_connection): + connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY})) + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.host_key, None) + @parameterized.expand( [ (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index 027de40c639a27..fea52bc5e01b68 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -51,8 +51,17 @@ def generate_key_string(pkey: paramiko.PKey, passphrase: Optional[str] = None): return key_str +def generate_host_key(pkey: paramiko.PKey): + key_fh = StringIO() + pkey.write_private_key(key_fh) + key_fh.seek(0) + key_obj = paramiko.RSAKey(file_obj=key_fh) + return key_obj.get_base64() + + TEST_PKEY = paramiko.RSAKey.generate(4096) TEST_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY) +TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY) PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10)) TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE) @@ -63,6 +72,10 @@ class TestSSHHook(unittest.TestCase): CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA = 'ssh_with_private_key_passphrase_extra' CONN_SSH_WITH_EXTRA = 'ssh_with_extra' CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS = 'ssh_with_extra_false_look_for_keys' + CONN_SSH_WITH_HOST_KEY_EXTRA = 'ssh_with_host_key_extra' + CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false' + CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true' + CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false' @classmethod def tearDownClass(cls) -> None: @@ -70,6 +83,11 @@ def tearDownClass(cls) -> None: conns_to_reset = [ cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA, cls.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA, + cls.CONN_SSH_WITH_EXTRA, + cls.CONN_SSH_WITH_HOST_KEY_EXTRA, + cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, + cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, + cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, ] connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset)) connections.delete(synchronize_session=False) @@ -116,6 +134,42 @@ def setUpClass(cls) -> None: ), ) ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_HOST_KEY_EXTRA, + host='localhost', + conn_type='ssh', + extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY}), + ) + ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, + host='remote_host', + conn_type='ssh', + extra=json.dumps( + {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": False} + ), + ) + ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, + host='remote_host', + conn_type='ssh', + extra=json.dumps( + {"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": True} + ), + ) + ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, + host='remote_host', + conn_type='ssh', + extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}), + ) + ) @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_with_password(self, ssh_mock): @@ -344,3 +398,42 @@ def test_ssh_connection_with_private_key_passphrase_extra(self, ssh_mock): sock=None, look_for_keys=True, ) + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_host_key_extra(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA) + assert hook.host_key is None # Since default no_host_key_check = True unless explicit override + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.get_host_keys.return_value.add.called is False + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE) + assert hook.host_key is None + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.get_host_keys.return_value.add.called is False + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_host_key_where_no_host_key_check_is_false(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE) + assert hook.host_key.get_base64() == TEST_HOST_KEY + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.get_host_keys.return_value.add.called is True + assert ssh_client.return_value.get_host_keys.return_value.add.call_args == mock.call( + hook.remote_host, 'ssh-rsa', hook.host_key + ) + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE) + assert hook.host_key is None + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.get_host_keys.return_value.add.called is False + + +if __name__ == '__main__': + unittest.main()