Skip to content

Commit

Permalink
[AIRFLOW-7044] Host key can be specified via SSH connection extras. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
triptec authored Jan 8, 2021
1 parent 6570df8 commit 52339a5
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Increase size of connection.extra field to handle multiple RSA keys
Revision ID: 449b4072c2da
Revises: e959f08ac86c
Create Date: 2020-03-16 19:02:55.337710
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '449b4072c2da'
down_revision = 'e959f08ac86c'
branch_labels = None
depends_on = None


def upgrade():
    """Change the ``connection.extra`` column from ``VARCHAR(5000)`` to ``TEXT``."""
    current_type = sa.VARCHAR(length=5000)
    target_type = sa.TEXT()
    with op.batch_alter_table('connection', schema=None) as batch_op:
        # Nullability is passed through unchanged; only the column type is altered.
        batch_op.alter_column(
            'extra', existing_type=current_type, type_=target_type, existing_nullable=True
        )


def downgrade():
    """Change the ``connection.extra`` column from ``TEXT`` back to ``VARCHAR(5000)``."""
    current_type = sa.TEXT()
    target_type = sa.VARCHAR(length=5000)
    with op.batch_alter_table('connection', schema=None) as batch_op:
        # Mirror image of upgrade(): same column, types swapped.
        batch_op.alter_column(
            'extra', existing_type=current_type, type_=target_type, existing_nullable=True
        )
2 changes: 1 addition & 1 deletion airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class Connection(Base, LoggingMixin): # pylint: disable=too-many-instance-attri
port = Column(Integer())
is_encrypted = Column(Boolean, unique=False, default=False)
is_extra_encrypted = Column(Boolean, unique=False, default=False)
_extra = Column('extra', String(5000))
_extra = Column('extra', Text())

def __init__( # pylint: disable=too-many-arguments
self,
Expand Down
6 changes: 6 additions & 0 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def get_conn(self) -> pysftp.Connection:
cnopts = pysftp.CnOpts()
if self.no_host_key_check:
cnopts.hostkeys = None
else:
if self.host_key is not None:
cnopts.hostkeys.add(self.remote_host, 'ssh-rsa', self.host_key)
else:
pass # will fallback to system host keys if none explicitly specified in conn extra

cnopts.compression = self.compress
cnopts.ciphers = self.ciphers
conn_params = {
Expand Down
18 changes: 15 additions & 3 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import getpass
import os
import warnings
from base64 import decodebytes
from io import StringIO
from typing import Dict, Optional, Tuple, Union

Expand All @@ -30,7 +31,7 @@
from airflow.hooks.base import BaseHook


class SSHHook(BaseHook):
class SSHHook(BaseHook): # pylint: disable=too-many-instance-attributes
"""
Hook for ssh remote execution using Paramiko.
ref: https://github.com/paramiko/paramiko
Expand Down Expand Up @@ -72,7 +73,7 @@ def get_ui_field_behaviour() -> Dict:
},
}

def __init__(
def __init__( # pylint: disable=too-many-statements
self,
ssh_conn_id: Optional[str] = None,
remote_host: Optional[str] = None,
Expand All @@ -99,6 +100,7 @@ def __init__(
self.no_host_key_check = True
self.allow_host_key_change = False
self.host_proxy = None
self.host_key = None
self.look_for_keys = True

# Placeholder for deprecated __enter__
Expand Down Expand Up @@ -149,7 +151,9 @@ def __init__(
and str(extra_options["look_for_keys"]).lower() == 'false'
):
self.look_for_keys = False

if "host_key" in extra_options and self.no_host_key_check is False:

This comment has been minimized.

Copy link
@malthe

malthe Jun 8, 2021

Contributor

What is the logic here? If you provide a host_key then surely that is the key to use. I think the no_host_key_check value should be ignored in this case.

I have proposed a change as part of #16314.

decoded_host_key = decodebytes(extra_options["host_key"].encode('utf-8'))
self.host_key = paramiko.RSAKey(data=decoded_host_key)
if self.pkey and self.key_file:
raise AirflowException(
"Params key_file and private_key both provided. Must provide no more than one."
Expand Down Expand Up @@ -198,10 +202,18 @@ def get_conn(self) -> paramiko.SSHClient:
'This wont protect against Man-In-The-Middle attacks'
)
client.load_system_host_keys()

if self.no_host_key_check:
self.log.warning('No Host Key Verification. This wont protect against Man-In-The-Middle attacks')
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
else:
if self.host_key is not None:
client_host_keys = client.get_host_keys()
client_host_keys.add(self.remote_host, 'ssh-rsa', self.host_key)
else:
pass # will fallback to system host keys if none explicitly specified in conn extra

connect_kwargs = dict(
hostname=self.remote_host,
username=self.username,
Expand Down
6 changes: 4 additions & 2 deletions docs/apache-airflow-providers-ssh/connections/ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ Extra (optional)
* ``private_key_passphrase`` - Content of the private key passphrase used to decrypt the private key.
* ``timeout`` - An optional timeout (in seconds) for the TCP connect. Default is ``10``.
* ``compress`` - ``true`` to ask the remote client/server to compress traffic; ``false`` to refuse compression. Default is ``true``.
* ``no_host_key_check`` - Set to ``false`` to restrict connecting to hosts with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
* ``no_host_key_check`` - Set to ``false`` to refuse connections to hosts whose key is neither listed in ``~/.ssh/known_hosts`` (Hosts file) nor supplied via the ``host_key`` extra. This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
* ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts whose host key has changed, or when you get the 'REMOTE HOST IDENTIFICATION HAS CHANGED' error. This won't protect against Man-In-The-Middle attacks. Another possible solution is to remove the host entry from the ``~/.ssh/known_hosts`` file. Default is ``false``.
* ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/``
* ``host_key`` - The base64 encoded ssh-rsa public key of the host, as you would find in the ``known_hosts`` file. Specifying this along with ``no_host_key_check=False`` allows the connection to be made only if the public key of the endpoint matches this value.

Example "extras" field:

Expand All @@ -59,9 +60,10 @@ Extra (optional)
"key_file": "/home/airflow/.ssh/id_rsa",
"timeout": "10",
"compress": "false",
"look_for_keys": "false",
"no_host_key_check": "false",
"allow_host_key_change": "false",
"look_for_keys": "false"
"host_key": "AAAHD...YDWwq=="
}
When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@ rootcss
rowid
rpc
rshift
rsa
rst
rtype
ru
Expand Down
41 changes: 40 additions & 1 deletion tests/providers/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,40 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import os
import shutil
import unittest
from io import StringIO
from unittest import mock

import paramiko
import pysftp
from parameterized import parameterized

from airflow.models import Connection
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.utils.session import provide_session


def generate_host_key(pkey: paramiko.PKey):
    """Return the base64 string used as the ``host_key`` extra in these tests.

    ``pkey`` is written out in private-key form to an in-memory buffer, read
    back as a ``paramiko.RSAKey``, and that key's ``get_base64()`` is returned.
    """
    serialized = StringIO()
    pkey.write_private_key(serialized)
    # Re-read the serialized text from the start via a fresh buffer.
    reloaded_key = paramiko.RSAKey(file_obj=StringIO(serialized.getvalue()))
    return reloaded_key.get_base64()


TMP_PATH = '/tmp'
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'

SFTP_CONNECTION_USER = "root"

TEST_PKEY = paramiko.RSAKey.generate(4096)
TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)


class TestSFTPHook(unittest.TestCase):
@provide_session
Expand Down Expand Up @@ -178,6 +192,31 @@ def test_no_host_key_check_no_ignore(self, get_connection):
hook = SFTPHook()
self.assertEqual(hook.no_host_key_check, False)

@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key_default(self, get_connection):
    """With no extras on the connection, the hook's ``host_key`` is expected to be ``None``."""
    get_connection.return_value = Connection(login='login', host='host')
    sftp_hook = SFTPHook()
    self.assertEqual(sftp_hook.host_key, None)

@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key(self, get_connection):
    """A ``host_key`` extra combined with ``no_host_key_check=False`` is exposed on the hook."""
    extras = {"host_key": TEST_HOST_KEY, "no_host_key_check": False}
    get_connection.return_value = Connection(login='login', host='host', extra=json.dumps(extras))
    sftp_hook = SFTPHook()
    # The hook stores a key object; compare through its base64 representation.
    self.assertEqual(sftp_hook.host_key.get_base64(), TEST_HOST_KEY)

@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key_with_no_host_key_check(self, get_connection):
    """A ``host_key`` extra given without ``no_host_key_check`` is expected to leave ``host_key`` as ``None``."""
    extras = {"host_key": TEST_HOST_KEY}
    get_connection.return_value = Connection(login='login', host='host', extra=json.dumps(extras))
    sftp_hook = SFTPHook()
    self.assertEqual(sftp_hook.host_key, None)

@parameterized.expand(
[
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
Expand Down
93 changes: 93 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,17 @@ def generate_key_string(pkey: paramiko.PKey, passphrase: Optional[str] = None):
return key_str


def generate_host_key(pkey: paramiko.PKey):
    """Return the base64 string used as the ``host_key`` extra in these tests.

    ``pkey`` is written out in private-key form to an in-memory buffer, read
    back as a ``paramiko.RSAKey``, and that key's ``get_base64()`` is returned.
    """
    serialized = StringIO()
    pkey.write_private_key(serialized)
    # Re-read the serialized text from the start via a fresh buffer.
    reloaded_key = paramiko.RSAKey(file_obj=StringIO(serialized.getvalue()))
    return reloaded_key.get_base64()


TEST_PKEY = paramiko.RSAKey.generate(4096)
TEST_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY)
TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY)

PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10))
TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE)
Expand All @@ -63,13 +72,22 @@ class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA = 'ssh_with_private_key_passphrase_extra'
CONN_SSH_WITH_EXTRA = 'ssh_with_extra'
CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS = 'ssh_with_extra_false_look_for_keys'
CONN_SSH_WITH_HOST_KEY_EXTRA = 'ssh_with_host_key_extra'
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false'
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true'
CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false'

@classmethod
def tearDownClass(cls) -> None:
with create_session() as session:
conns_to_reset = [
cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
cls.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA,
cls.CONN_SSH_WITH_EXTRA,
cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
]
connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
Expand Down Expand Up @@ -116,6 +134,42 @@ def setUpClass(cls) -> None:
),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_HOST_KEY_EXTRA,
host='localhost',
conn_type='ssh',
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY}),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
host='remote_host',
conn_type='ssh',
extra=json.dumps(
{"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": False}
),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
host='remote_host',
conn_type='ssh',
extra=json.dumps(
{"private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": True}
),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
host='remote_host',
conn_type='ssh',
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}),
)
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
Expand Down Expand Up @@ -344,3 +398,42 @@ def test_ssh_connection_with_private_key_passphrase_extra(self, ssh_mock):
sock=None,
look_for_keys=True,
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_host_key_extra(self, ssh_client):
    """A ``host_key`` extra alone is not loaded and no host key is added to the client."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
    # Since default no_host_key_check = True unless explicit override
    assert hook.host_key is None
    with hook.get_conn():
        client = ssh_client.return_value
        assert client.connect.called is True
        assert client.get_host_keys.return_value.add.called is False

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self, ssh_client):
    """With ``no_host_key_check=True`` the ``host_key`` extra is not loaded or added to the client."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)
    assert hook.host_key is None
    with hook.get_conn():
        client = ssh_client.return_value
        assert client.connect.called is True
        assert client.get_host_keys.return_value.add.called is False

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_host_key_where_no_host_key_check_is_false(self, ssh_client):
    """With ``no_host_key_check=False`` the ``host_key`` extra is loaded and added to the client's host keys."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
    assert hook.host_key.get_base64() == TEST_HOST_KEY
    with hook.get_conn():
        client = ssh_client.return_value
        add_host_key = client.get_host_keys.return_value.add
        assert client.connect.called is True
        assert add_host_key.called is True
        # The key must be registered for the hook's remote host as an ssh-rsa key.
        assert add_host_key.call_args == mock.call(hook.remote_host, 'ssh-rsa', hook.host_key)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, ssh_client):
    """With ``no_host_key_check=False`` but no ``host_key`` extra, nothing is added to the client's host keys."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
    assert hook.host_key is None
    with hook.get_conn():
        client = ssh_client.return_value
        assert client.connect.called is True
        assert client.get_host_keys.return_value.add.called is False


if __name__ == '__main__':
unittest.main()

0 comments on commit 52339a5

Please sign in to comment.