Skip to content

Commit

Permalink
[AIRFLOW-4574] add option to provide private_key in SSHHook (apache#6104
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dstandish authored and mik-laj committed Sep 17, 2019
1 parent c098ff7 commit fa8e18a
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 22 deletions.
57 changes: 38 additions & 19 deletions airflow/contrib/hooks/ssh_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import getpass
import os
import warnings
from io import StringIO

import paramiko
from paramiko.config import SSH_PORT
Expand All @@ -45,8 +46,10 @@ class SSHHook(BaseHook):
:type username: str
:param password: password of the username to connect to the remote_host
:type password: str
:param key_file: key file to use to connect to the remote_host.
:param key_file: path to key file to use to connect to the remote_host
:type key_file: str
:param private_key: content of key file to use to connect to remote_host
:type private_key: str
:param port: port of remote host to connect (Default is paramiko SSH_PORT)
:type port: int
:param timeout: timeout for the attempt to connect to the remote_host.
Expand All @@ -62,6 +65,7 @@ def __init__(self,
username=None,
password=None,
key_file=None,
private_key=None,
port=None,
timeout=10,
keepalive_interval=30
Expand All @@ -71,6 +75,8 @@ def __init__(self,
self.username = username
self.password = password
self.key_file = key_file
self.private_key = private_key
self.pkey = None
self.port = port
self.timeout = timeout
self.keepalive_interval = keepalive_interval
Expand Down Expand Up @@ -100,6 +106,9 @@ def __init__(self,
if "key_file" in extra_options and self.key_file is None:
self.key_file = extra_options.get("key_file")

if not self.private_key:
self.private_key = extra_options.get('private_key')

if "timeout" in extra_options:
self.timeout = int(extra_options["timeout"], 10)

Expand All @@ -115,6 +124,13 @@ def __init__(self,
str(extra_options["allow_host_key_change"]).lower() == 'true':
self.allow_host_key_change = True

if self.private_key and self.key_file:
raise AirflowException(
"Params key_file and private_key both provided. Must provide no more than one.")

if self.private_key:
self.pkey = paramiko.RSAKey.from_private_key(StringIO(self.private_key))

if not self.remote_host:
raise AirflowException("Missing required param: remote_host")

Expand Down Expand Up @@ -150,6 +166,7 @@ def get_conn(self):

self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
client = paramiko.SSHClient()

if not self.allow_host_key_change:
self.log.warning('Remote Identification Change is not verified. '
'This wont protect against Man-In-The-Middle attacks')
Expand All @@ -159,24 +176,26 @@ def get_conn(self):
'against Man-In-The-Middle attacks')
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

if self.password and self.password.strip():
client.connect(hostname=self.remote_host,
username=self.username,
password=self.password,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)
else:
client.connect(hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)
connect_kwargs = dict(
hostname=self.remote_host,
username=self.username,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy
)

if self.password:
password = self.password.strip()
connect_kwargs.update(password=password)

if self.pkey:
connect_kwargs.update(pkey=self.pkey)

if self.key_file:
connect_kwargs.update(key_filename=self.key_file)

client.connect(**connect_kwargs)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)
Expand Down
9 changes: 8 additions & 1 deletion docs/howto/connection/ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Extra (optional)
are supported:

* **key_file** - Full Path of the private SSH Key file that will be used to connect to the remote_host.
* **private_key** - Content of the private key used to connect to the remote_host.
* **timeout** - An optional timeout (in seconds) for the TCP connect. Default is ``10``.
* **compress** - ``true`` to ask the remote client/server to compress traffic; ``false`` to refuse compression. Default is ``true``.
* **no_host_key_check** - Set to ``false`` to restrict connecting to hosts with no entries in ``~/.ssh/known_hosts`` (Hosts file). This provides maximum protection against trojan horse attacks, but can be troublesome when the ``/etc/ssh/ssh_known_hosts`` file is poorly maintained or connections to new hosts are frequently made. This option forces the user to manually add all new hosts. Default is ``true``, ssh will automatically add new host keys to the user known hosts files.
Expand All @@ -62,8 +63,14 @@ Extra (optional)
following the standard syntax of connections, where extras are passed as parameters
of the URI (note that all components of the URI should be URL-encoded).

For example:
Example connection string with ``key_file`` (path to key file provided in connection):

.. code-block:: bash
export AIRFLOW_CONN_MAIN_SERVER='ssh://user:pass@localhost:22?timeout=10&compress=false&no_host_key_check=false&allow_host_key_change=true&key_file=%2Fhome%2Fairflow%2F.ssh%2Fid_rsa'
Example connection string with ``private_key`` (actual private key provided in connection):

.. code-block:: bash
export AIRFLOW_CONN_SSH_SERVER='SSH://127.0.0.1?private_key=-----BEGIN+RSA+PRIVATE+KEY-----%0AMIIEpAIBAAKCAQEAvYUM9xouSUtCKMwm%2FkogT4r3Y%2Bh7H0IPnd7DF9sKCHt9FPJ%2B%0ALaQNX%2FRgnOoPf5ySN42A1nmqv4WX5AKdjEYMIJzN2g2whnol8RVjzP4s2Ao%2B%2BWJ9%0AKstey85CQUgjWFO57ye3TyhbfMZI3fBqDX5RjgkgAZmUpKmv6ttSiCfdgGxLweD7%0ADZexlAjuSfr7i0UZWBIbSKJdePMnWGvZZO%2BGerGlOIKs%2Bqx5agMbNJqDhWn0u8OV%0ACMANhc0yaUAbN08Pjac94%2FxmZPHASytrBmTGd6zYcuzOyxwK8KHMeLUagByT3u7l%0AvWcVyRx8FAXkl7nGF2SQZ0z3JLhmdWMSXuc1AQIDAQABAoIBAQC8%2Bp1REVQyVc8k%0A612%2Bl5%2FccU%2F62elb4%2F26iFS1xv8cMjcp2hwj2sBTfFWSYnsN3syWhI2CUFQJImex%0AP0Jmi7qwEmvaEWiCz%2B5hldisoo%2BI5b6h4qm5MI3YYFYEzrAf9W0kos%2FRKQcBRp%2BG%0AX6MAzYL5RPQbZE%2BqWmJGqGiFyGrBEISl%2FMdoaqSJewTRLHwDtbD9lt4WRPUO%2Font%0A%2FUKwOu3i9z5hMQm9HJJLuKr3hl5jmjJbJUg50a7fjVJzr52VfxH73Z%2Fst40fD3x4%0AH1DHGbX4ar9JOYvhzdXkuxyNXvoglJUIOiAk23Od8q9xOMQAITuwkc1QaVRXwiE7%0Aw41lMC8ZAoGBAOB9PEFyzGwYZgiReOQsAJrlwT7zsY053OGSXAeoLC2OzyLNb8v7%0AnKy2qoTMwxe9LHUDDAp6I8btprvLq35Y72iCbGg0ZK5fIYv%2Bt03NjvOOl1zEuUny%0A5xGe1IvP4YgMQuVMVw5dj11Jmna5eW3oFXlyOQrlth9hrexuI%2BG25qwvAoGBANgf%0AOhy%2FofyIgrIGwaRbgg55rlqViLNGFcJ6I3dVlsRqFxND4PvQZZWfCN3LhIGgI8cT%0AN6hFGPR9QrsmXe3eHM7%2FUpMk53oiPD9E0MemPtQh2AFPUb%2BznqxrXNGvtww6xYBM%0AKYLXcQVn%2FKELwwMYw3F0HGKgCFF0XthV34f%2Bt%2FXPAoGBALVLjqEQlBTsM2LSEP68%0AppRx3nn3lrmGNGMbryUj5OG6BoCFxrbG8gXt05JCR4Bhb4jkOBIyB7i87r2VQ19b%0AdaVCR0h0n6bO%2FymvQNwdmUgLLSRnX3hgKcpqKh7reKlFtbS2zUu1tXVSXuNo8K8Z%0AElatL3Ikh8uaODrLzECaVHpTAoGAXcReoC58h2Zq3faUeUzChqlAfki2gKF9u1zm%0AmlXmDd3BmTgwGtD14g6X%2BDLekKb8Htk1oqooA5t9IlmpExT1BtI7719pltHXtdOT%0AiauVQtBUOW1CmJvD0ibapJdKIeI14k4pDH2QqbnOH8lMmMFbupOX5SptsXl91Pqc%0A%2BxIGmn0CgYBOL2o0Sn%2F8d7uzAZKUBG1%2F0eFr4j6wYwWajVDFOfbJ7WdIf5j%2BL3nY%0A3440i%2Fb2NlEE8nLPDl6cwiOtwV0XFkoiF3ctHvutlhGBxAKHetIxIsnQk7vXqgfP%0AnhsgNypNAQXbxe3gjJEb4Fzw3Ufz3mq5PllYtXKhc%2Bmc4%2B3sN5uGow%3D%3D%0A-----END+RSA+PRIVATE+KEY-----%0A'
80 changes: 78 additions & 2 deletions tests/contrib/hooks/test_ssh_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
# under the License.

import unittest

import json
from io import StringIO
import paramiko
from airflow.models import Connection
from airflow.utils import db
from airflow.contrib.hooks.ssh_hook import SSHHook

from airflow.utils.db import create_session
from tests.compat import mock

HELLO_SERVER_CMD = """
Expand All @@ -38,7 +40,30 @@
"""


def generate_key_string(pkey: paramiko.PKey):
    """Serialize ``pkey`` to its private-key text form and return it as a string.

    :param pkey: key object whose ``write_private_key`` method is used to
        produce the text.
    :return: everything ``write_private_key`` wrote into the in-memory buffer.
    """
    buffer = StringIO()
    pkey.write_private_key(buffer)
    # getvalue() returns the whole buffer regardless of the current position,
    # so no rewind is needed before reading it back.
    return buffer.getvalue()


# RSA key generated once at import time and shared by the private-key tests below.
TEST_PKEY = paramiko.RSAKey.generate(4096)
# Text form of TEST_PKEY, passed to SSHHook as ``private_key`` (directly or via connection extras).
TEST_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY)


class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_PRIVATE_KEY_EXTRA = 'ssh_with_private_key_extra'

def tearDown(self) -> None:
    """Delete the ``Connection`` rows that tests in this class may have created."""
    stale_conn_ids = [
        self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
    ]
    with create_session() as session:
        # Bulk delete by conn_id; synchronize_session=False skips reconciling
        # in-session objects, none of which are used after this point.
        session.query(Connection).filter(
            Connection.conn_id.in_(stale_conn_ids)
        ).delete(synchronize_session=False)
        session.commit()

@mock.patch('airflow.contrib.hooks.ssh_hook.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
hook = SSHHook(remote_host='remote_host',
Expand Down Expand Up @@ -164,6 +189,57 @@ def test_tunnel(self):
server_handle.communicate()
self.assertEqual(server_handle.returncode, 0)

@mock.patch('airflow.contrib.hooks.ssh_hook.paramiko.SSHClient')
def test_ssh_connection_with_private_key(self, ssh_mock):
    """A hook built with ``private_key`` must call ``connect`` with the matching ``pkey``."""
    hook = SSHHook(
        remote_host='remote_host',
        port='port',
        username='username',
        timeout=10,
        private_key=TEST_PRIVATE_KEY,
    )

    # Exact keyword arguments expected on the single connect() call:
    # the key is passed as ``pkey``; no password or key_filename is expected.
    expected_connect_kwargs = dict(
        hostname='remote_host',
        username='username',
        pkey=TEST_PKEY,
        timeout=10,
        compress=True,
        port='port',
        sock=None,
    )

    with hook.get_conn():
        connect_mock = ssh_mock.return_value.connect
        connect_mock.assert_called_once_with(**expected_connect_kwargs)

@mock.patch('airflow.contrib.hooks.ssh_hook.paramiko.SSHClient')
def test_ssh_connection_with_private_key_extra(self, ssh_mock):
    """A ``private_key`` stored in the connection extras must reach ``connect`` as ``pkey``."""
    # Register a connection whose only extra is the private key text.
    conn_extra = json.dumps({
        "private_key": TEST_PRIVATE_KEY,
    })
    private_key_conn = Connection(
        conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
        host='localhost',
        conn_type='ssh',
        extra=conn_extra,
    )
    db.merge_conn(private_key_conn)

    hook = SSHHook(
        ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
        remote_host='remote_host',
        port='port',
        username='username',
        timeout=10,
    )

    # The hostname asserted below is the explicit remote_host argument,
    # not the 'localhost' host stored on the connection.
    expected_connect_kwargs = dict(
        hostname='remote_host',
        username='username',
        pkey=TEST_PKEY,
        timeout=10,
        compress=True,
        port='port',
        sock=None,
    )

    with hook.get_conn():
        connect_mock = ssh_mock.return_value.connect
        connect_mock.assert_called_once_with(**expected_connect_kwargs)


if __name__ == '__main__':
unittest.main()

0 comments on commit fa8e18a

Please sign in to comment.