Skip to content

Commit

Permalink
Support ssl_context attributes which expose alpn (#756)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Friedman <[email protected]>
  • Loading branch information
RFRIEDM-Trimble authored Aug 29, 2022
1 parent cd88919 commit ae4a5b1
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 13 deletions.
1 change: 1 addition & 0 deletions tavern/_core/schema/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def verify_jsonschema(to_verify, schema):
try:
validator.validate(to_verify)
except jsonschema.ValidationError as e:
print(e)
real_context = []

# ignore these strings because they're red herrings
Expand Down
135 changes: 123 additions & 12 deletions tavern/_plugins/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ def __init__(self, topic, subscribed=False):
self.subscribed = subscribed


def check_file_exists(key, filename):
try:
with open(filename, "r", encoding="utf-8"):
pass
except IOError as e:
raise exceptions.MQTTTLSError(
"Couldn't load '{}' from '{}'".format(key, filename)
) from e


def _handle_tls_args(tls_args):
"""Make sure TLS options are valid"""

Expand All @@ -55,21 +65,13 @@ def _handle_tls_args(tls_args):
"If specifying a TLS keyfile, a certfile also needs to be specified"
)

def check_file_exists(key):
# could be moved to schema validation stage
for key in ["certfile", "keyfile"]:
try:
with open(tls_args[key], "r", encoding="utf-8"):
pass
except IOError as e:
raise exceptions.MQTTTLSError(
"Couldn't load '{}' from '{}'".format(key, tls_args[key])
) from e
check_file_exists(key, tls_args[key])
except KeyError:
pass

# could be moved to schema validation stage
check_file_exists("certfile")
check_file_exists("keyfile")

# This shouldn't raise an AttributeError because it's enumerated
try:
tls_args["cert_reqs"] = getattr(ssl, tls_args["cert_reqs"])
Expand All @@ -91,6 +93,43 @@ def check_file_exists(key):
return tls_args


def _handle_ssl_context_args(ssl_context_args):
"""Make sure SSL Context options are valid"""
if not ssl_context_args:
return None

if "keyfile" in ssl_context_args and "certfile" not in ssl_context_args:
raise exceptions.MQTTTLSError(
"If specifying a TLS keyfile, a certfile also needs to be specified"
)

# could be moved to schema validation stage
check_file_exists("certfile", ssl_context_args["certfile"])
check_file_exists("keyfile", ssl_context_args["keyfile"])
if "cafile" in ssl_context_args:
check_file_exists("cafile", ssl_context_args["cafile"])

# This shouldn't raise an AttributeError because it's enumerated
try:
ssl_context_args["cert_reqs"] = getattr(ssl, ssl_context_args["cert_reqs"])
except KeyError:
pass

try:
ssl_context_args["tls_version"] = getattr(ssl, ssl_context_args["tls_version"])
except AttributeError as e:
raise exceptions.MQTTTLSError(
"Error getting TLS version from "
"ssl module - ssl module had no attribute '{}'. Check the "
"documentation for the version of python you're using to see "
"if this a valid option.".format(ssl_context_args["tls_version"])
) from e
except KeyError:
pass

return ssl_context_args


class MQTTClient:
# pylint: disable=too-many-instance-attributes

Expand All @@ -116,6 +155,15 @@ def __init__(self, **kwargs):
"ciphers",
},
"auth": {"username", "password"},
"ssl_context": {
"ca_certs",
"certfile",
"keyfile",
"password",
"tls_version",
"ciphers",
"alpn_protocols",
},
}

logger.debug("Initialising MQTT client with %s", kwargs)
Expand All @@ -139,12 +187,24 @@ def __init__(self, **kwargs):

self._connect_timeout = self._connect_args.pop("timeout", 3)

# If there is any tls kwarg (including 'enable'), enable tls
# If there is any tls or ssl_context kwarg, configure tls encryption
file_tls_args = kwargs.pop("tls", {})
file_ssl_context_args = kwargs.pop("ssl_context", {})

if file_tls_args and file_ssl_context_args:
msg = (
"'tls' and 'ssl_context' are both specified but are mutually exclusive"
)
raise exceptions.MQTTTLSError(msg)

check_expected_keys(expected_blocks["tls"], file_tls_args)
self._tls_args = _handle_tls_args(file_tls_args)
logger.debug("TLS is %s", "enabled" if self._tls_args else "disabled")

# If there is any SSL kwarg, enable tls through the SSL context
check_expected_keys(expected_blocks["ssl_context"], file_ssl_context_args)
self._ssl_context_args = _handle_ssl_context_args(file_ssl_context_args)

logger.debug("Paho client args: %s", self._client_args)
self._client = paho.Client(**self._client_args)
self._client.enable_logger()
Expand All @@ -171,6 +231,57 @@ def __init__(self, **kwargs):
"Unexpected SSL error enabling TLS"
) from e

if self._ssl_context_args:
# Create SSLContext object
tls_version = self._ssl_context_args.get("tls_version")
if tls_version is None:
# If the python version supports it, use highest TLS version automatically
if hasattr(ssl, "PROTOCOL_TLS_CLIENT"):
tls_version = ssl.PROTOCOL_TLS_CLIENT
elif hasattr(ssl, "PROTOCOL_TLS"):
tls_version = ssl.PROTOCOL_TLS
else:
tls_version = ssl.PROTOCOL_TLSv1_2
ca_certs = self._ssl_context_args.get("cert_reqs")
context = ssl.create_default_context(cafile=ca_certs)

certfile = self._ssl_context_args.get("certfile")
keyfile = self._ssl_context_args.get("keyfile")
password = self._ssl_context_args.get("password")

# Configure context
if certfile is not None:
context.load_cert_chain(certfile, keyfile, password)

cert_reqs = self._ssl_context_args.get("cert_reqs")
if cert_reqs == ssl.CERT_NONE and hasattr(context, "check_hostname"):
context.check_hostname = False

context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs

if ca_certs is not None:
context.load_verify_locations(ca_certs)
else:
context.load_default_certs()

ciphers = self._ssl_context_args.get("cipthers")
if ciphers is not None:
context.set_ciphers(ciphers)

alpn_protocols = self._ssl_context_args.get("alpn_protocols")
if alpn_protocols is not None:
context.set_alpn_protocols(alpn_protocols)

self._client.tls_set_context(context)

if cert_reqs != ssl.CERT_NONE:
# Default to secure, sets context.check_hostname attribute
# if available
self._client.tls_insecure_set(False)
else:
# But with ssl.CERT_NONE, we can not check_hostname
self._client.tls_insecure_set(True)

# Arbitrary number, could just be 1 and only accept 1 message per stages
# but we might want to raise an error if more than 1 message is received
# during a test stage.
Expand Down
47 changes: 46 additions & 1 deletion tavern/_plugins/mqtt/jsonschema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ properties:
description: How long to wait for connection before giving up

tls:
description: Custom options to control secure connection
description: Basic custom options to control secure connection

type: object
additionalProperties: false
Expand Down Expand Up @@ -107,6 +107,51 @@ properties:
type: string
description: Allowed ciphers to use with connection

ssl_context:
description: Advanced custom options to control secure connection using SSLContext

type: object
additionalProperties: false

properties:
ca_certs:
type: string
description: Path to CA cert bundle

certfile:
type: string
description: Path to certificate for server

keyfile:
type: string
description: Path to private key for client

password:
type: string
description: Password for keyfile

cert_reqs:
type: string
description: Controls connection with cert
enum:
- CERT_NONE
- CERT_OPTIONAL
- CERT_REQUIRED

tls_version:
type: string
description: TLS version to use

ciphers:
type: string
description: Allowed ciphers to use with connection

alpn_protocols:
type: array
description: |
Which protocols the socket should advertise during the SSL/TLS handshake.
See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_alpn_protocols
auth:
description: Username and password for basic authorisation

Expand Down
45 changes: 45 additions & 0 deletions tavern/_plugins/mqtt/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,51 @@ initialisation:
required: false
type: str

ssl_context:
required: false
type: map
mapping:
ca_certs:
required: false
type: str

certfile:
required: false
type: str

keyfile:
required: false
type: str

password:
required: false
type: str
# This is the password for the keyfile, and is only needed if the keyfile is password encrypted
# If not supplied, but the keyfile is password protect, the ssl module will prompt for a password in terminal

cert_reqs:
required: false
type: str
enum:
- CERT_NONE
- CERT_OPTIONAL
- CERT_REQUIRED

tls_version:
required: false
type: str
# This could be an enum but there's lots of them, and which ones are
# actually valid changes based on which version of python you're
# using. Just let any ssl errors propagate through

ciphers:
required: false
type: str

alpn_protocols:
required: false
type: array

auth:
required: false
type: map
Expand Down

0 comments on commit ae4a5b1

Please sign in to comment.