Skip to content

Commit

Permalink
[AIRFLOW-3984] Add tests for WinRMHook (#4811)
Browse files Browse the repository at this point in the history
- fix docs
- refactoring code

[AIRFLOW-3984] Change docs to be clearer about None values
  • Loading branch information
feluelle authored and Fokko committed Mar 6, 2019
1 parent f3156d7 commit 92638e8
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 39 deletions.
62 changes: 23 additions & 39 deletions airflow/contrib/hooks/winrm_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,18 @@ class WinRMHook(BaseHook):
:seealso: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
:param ssh_conn_id: connection id from airflow Connections from where all
the required parameters can be fetched like username and password.
:param ssh_conn_id: connection id from airflow Connections from where
all the required parameters can be fetched like username and password.
Thought the priority is given to the param passed during init
:type ssh_conn_id: str
:param endpoint: When set to `None`, endpoint will be constructed like this:
:param endpoint: When not set, endpoint will be constructed like this:
'http://{remote_host}:{remote_port}/wsman'
:type endpoint: str
:param remote_host: Remote host to connect to.
Ignored if `endpoint` is not `None`.
:param remote_host: Remote host to connect to. Ignored if `endpoint` is set.
:type remote_host: str
:param remote_port: Remote port to connect to.
Ignored if `endpoint` is not `None`.
:param remote_port: Remote port to connect to. Ignored if `endpoint` is set.
:type remote_port: int
:param transport: transport type, one of 'plaintext' (default), 'kerberos', 'ssl',
'ntlm', 'credssp'
:param transport: transport type, one of 'plaintext' (default), 'kerberos', 'ssl', 'ntlm', 'credssp'
:type transport: str
:param username: username to connect to the remote_host
:type username: str
Expand All @@ -63,32 +60,33 @@ class WinRMHook(BaseHook):
:param cert_key_pem: client authentication certificate key file path in PEM format
:type cert_key_pem: str
:param server_cert_validation: whether server certificate should be validated on
Python versions that suppport it; one of 'validate' (default), 'ignore'
Python versions that support it; one of 'validate' (default), 'ignore'
:type server_cert_validation: str
:param kerberos_delegation: if True, TGT is sent to target server to
allow multiple hops
:type kerberos_delegation: bool
:param read_timeout_sec: maximum seconds to wait before an HTTP connect/read times out
(default 30). This value should be slightly higher than operation_timeout_sec,
:param read_timeout_sec: maximum seconds to wait before an HTTP connect/read times out (default 30).
This value should be slightly higher than operation_timeout_sec,
as the server can block *at least* that long.
:type read_timeout_sec: int
:param operation_timeout_sec: maximum allowed time in seconds for any single wsman
HTTP operation (default 20). Note that operation timeouts while receiving output
(the only wsman operation that should take any significant time, and where these
timeouts are expected) will be silently retried indefinitely.
(the only wsman operation that should take any significant time,
and where these timeouts are expected) will be silently retried indefinitely.
:type operation_timeout_sec: int
:param kerberos_hostname_override: the hostname to use for the kerberos exchange
(defaults to the hostname in the endpoint URL)
:type kerberos_hostname_override: str
:param message_encryption_enabled: Will encrypt the WinRM messages if set to True and
the transport auth supports message encryption (Default True).
:type message_encryption_enabled: bool
:param message_encryption: Will encrypt the WinRM messages if set
and the transport auth supports message encryption. (Default 'auto')
:type message_encryption: str
:param credssp_disable_tlsv1_2: Whether to disable TLSv1.2 support and work with older
protocols like TLSv1.0, default is False
:type credssp_disable_tlsv1_2: bool
:param send_cbt: Will send the channel bindings over a HTTPS channel (Default: True)
:type send_cbt: bool
"""

def __init__(self,
ssh_conn_id=None,
endpoint=None,
Expand All @@ -109,8 +107,7 @@ def __init__(self,
kerberos_hostname_override=None,
message_encryption='auto',
credssp_disable_tlsv1_2=False,
send_cbt=True,
):
send_cbt=True):
super(WinRMHook, self).__init__(ssh_conn_id)
self.ssh_conn_id = ssh_conn_id
self.endpoint = endpoint
Expand Down Expand Up @@ -171,19 +168,15 @@ def get_conn(self):
if "cert_key_pem" in extra_options:
self.cert_key_pem = str(extra_options["cert_key_pem"])
if "server_cert_validation" in extra_options:
self.server_cert_validation = \
str(extra_options["server_cert_validation"])
self.server_cert_validation = str(extra_options["server_cert_validation"])
if "kerberos_delegation" in extra_options:
self.kerberos_delegation = \
str(extra_options["kerberos_delegation"]).lower() == 'true'
self.kerberos_delegation = str(extra_options["kerberos_delegation"]).lower() == 'true'
if "read_timeout_sec" in extra_options:
self.read_timeout_sec = int(extra_options["read_timeout_sec"])
if "operation_timeout_sec" in extra_options:
self.operation_timeout_sec = \
int(extra_options["operation_timeout_sec"])
self.operation_timeout_sec = int(extra_options["operation_timeout_sec"])
if "kerberos_hostname_override" in extra_options:
self.kerberos_hostname_override = \
str(extra_options["kerberos_hostname_override"])
self.kerberos_hostname_override = str(extra_options["kerberos_hostname_override"])
if "message_encryption" in extra_options:
self.message_encryption = str(extra_options["message_encryption"])
if "credssp_disable_tlsv1_2" in extra_options:
Expand All @@ -206,10 +199,7 @@ def get_conn(self):

# If endpoint is not set, then build a standard wsman endpoint from host and port.
if not self.endpoint:
self.endpoint = 'http://{0}:{1}/wsman'.format(
self.remote_host,
self.remote_port
)
self.endpoint = 'http://{0}:{1}/wsman'.format(self.remote_host, self.remote_port)

try:
if self.password and self.password.strip():
Expand All @@ -233,17 +223,11 @@ def get_conn(self):
send_cbt=self.send_cbt
)

self.log.info(
"Establishing WinRM connection to host: %s",
self.remote_host
)
self.log.info("Establishing WinRM connection to host: %s", self.remote_host)
self.client = self.winrm_protocol.open_shell()

except Exception as error:
error_msg = "Error connecting to host: {0}, error: {1}".format(
self.remote_host,
error
)
error_msg = "Error connecting to host: {0}, error: {1}".format(self.remote_host, error)
self.log.error(error_msg)
raise AirflowException(error_msg)

Expand Down
120 changes: 120 additions & 0 deletions tests/contrib/hooks/test_winrm_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import unittest

from mock import patch

from airflow import AirflowException
from airflow.contrib.hooks.winrm_hook import WinRMHook
from airflow.models.connection import Connection


class TestWinRMHook(unittest.TestCase):

@patch('airflow.contrib.hooks.winrm_hook.Protocol')
def test_get_conn_exists(self, mock_protocol):
winrm_hook = WinRMHook()
winrm_hook.client = mock_protocol.return_value.open_shell.return_value

conn = winrm_hook.get_conn()

self.assertEqual(conn, winrm_hook.client)

def test_get_conn_missing_remote_host(self):
with self.assertRaises(AirflowException):
WinRMHook().get_conn()

@patch('airflow.contrib.hooks.winrm_hook.Protocol')
def test_get_conn_error(self, mock_protocol):
mock_protocol.side_effect = Exception('Error')

with self.assertRaises(AirflowException):
WinRMHook(remote_host='host').get_conn()

@patch('airflow.contrib.hooks.winrm_hook.Protocol')
@patch('airflow.contrib.hooks.winrm_hook.WinRMHook.get_connection',
return_value=Connection(
login='username',
password='password',
host='remote_host',
extra="""{
"endpoint": "endpoint",
"remote_port": 123,
"transport": "transport",
"service": "service",
"keytab": "keytab",
"ca_trust_path": "ca_trust_path",
"cert_pem": "cert_pem",
"cert_key_pem": "cert_key_pem",
"server_cert_validation": "server_cert_validation",
"kerberos_delegation": "true",
"read_timeout_sec": 123,
"operation_timeout_sec": 123,
"kerberos_hostname_override": "kerberos_hostname_override",
"message_encryption": "message_encryption",
"credssp_disable_tlsv1_2": "true",
"send_cbt": "false"
}"""
))
def test_get_conn_from_connection(self, mock_get_connection, mock_protocol):
connection = mock_get_connection.return_value
winrm_hook = WinRMHook(ssh_conn_id='conn_id')

winrm_hook.get_conn()

mock_get_connection.assert_called_once_with(winrm_hook.ssh_conn_id)
mock_protocol.assert_called_once_with(
endpoint=str(connection.extra_dejson['endpoint']),
transport=str(connection.extra_dejson['transport']),
username=connection.login,
password=connection.password,
service=str(connection.extra_dejson['service']),
keytab=str(connection.extra_dejson['keytab']),
ca_trust_path=str(connection.extra_dejson['ca_trust_path']),
cert_pem=str(connection.extra_dejson['cert_pem']),
cert_key_pem=str(connection.extra_dejson['cert_key_pem']),
server_cert_validation=str(connection.extra_dejson['server_cert_validation']),
kerberos_delegation=str(connection.extra_dejson['kerberos_delegation']).lower() == 'true',
read_timeout_sec=int(connection.extra_dejson['read_timeout_sec']),
operation_timeout_sec=int(connection.extra_dejson['operation_timeout_sec']),
kerberos_hostname_override=str(connection.extra_dejson['kerberos_hostname_override']),
message_encryption=str(connection.extra_dejson['message_encryption']),
credssp_disable_tlsv1_2=str(connection.extra_dejson['credssp_disable_tlsv1_2']).lower() == 'true',
send_cbt=str(connection.extra_dejson['send_cbt']).lower() == 'true'
)

@patch('airflow.contrib.hooks.winrm_hook.getpass.getuser', return_value='user')
@patch('airflow.contrib.hooks.winrm_hook.Protocol')
def test_get_conn_no_username(self, mock_protocol, mock_getuser):
winrm_hook = WinRMHook(remote_host='host', password='password')

winrm_hook.get_conn()

self.assertEqual(mock_getuser.return_value, winrm_hook.username)

@patch('airflow.contrib.hooks.winrm_hook.Protocol')
def test_get_conn_no_endpoint(self, mock_protocol):
winrm_hook = WinRMHook(remote_host='host', password='password')

winrm_hook.get_conn()

self.assertEqual('http://{0}:{1}/wsman'.format(winrm_hook.remote_host, winrm_hook.remote_port),
winrm_hook.endpoint)

0 comments on commit 92638e8

Please sign in to comment.