diff --git a/tavern/_core/schema/jsonschema.py b/tavern/_core/schema/jsonschema.py index b8951ad8..59a3e060 100644 --- a/tavern/_core/schema/jsonschema.py +++ b/tavern/_core/schema/jsonschema.py @@ -122,6 +122,7 @@ def verify_jsonschema(to_verify, schema): try: validator.validate(to_verify) except jsonschema.ValidationError as e: + print(e) real_context = [] # ignore these strings because they're red herrings diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index c7623913..9d18f103 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -39,6 +39,16 @@ def __init__(self, topic, subscribed=False): self.subscribed = subscribed +def check_file_exists(key, filename): + try: + with open(filename, "r", encoding="utf-8"): + pass + except IOError as e: + raise exceptions.MQTTTLSError( + "Couldn't load '{}' from '{}'".format(key, filename) + ) from e + + def _handle_tls_args(tls_args): """Make sure TLS options are valid""" @@ -55,21 +65,13 @@ def _handle_tls_args(tls_args): "If specifying a TLS keyfile, a certfile also needs to be specified" ) - def check_file_exists(key): + # could be moved to schema validation stage + for key in ["certfile", "keyfile"]: try: - with open(tls_args[key], "r", encoding="utf-8"): - pass - except IOError as e: - raise exceptions.MQTTTLSError( - "Couldn't load '{}' from '{}'".format(key, tls_args[key]) - ) from e + check_file_exists(key, tls_args[key]) except KeyError: pass - # could be moved to schema validation stage - check_file_exists("certfile") - check_file_exists("keyfile") - # This shouldn't raise an AttributeError because it's enumerated try: tls_args["cert_reqs"] = getattr(ssl, tls_args["cert_reqs"]) @@ -91,6 +93,43 @@ def check_file_exists(key): return tls_args +def _handle_ssl_context_args(ssl_context_args): + """Make sure SSL Context options are valid""" + if not ssl_context_args: + return None + + if "keyfile" in ssl_context_args and "certfile" not in ssl_context_args: + raise exceptions.MQTTTLSError( + "If specifying a TLS keyfile, a certfile also needs to be specified" + ) + + # could be moved to schema validation stage + check_file_exists("certfile", ssl_context_args["certfile"]) + check_file_exists("keyfile", ssl_context_args["keyfile"]) + if "cafile" in ssl_context_args: + check_file_exists("cafile", ssl_context_args["cafile"]) + + # This shouldn't raise an AttributeError because it's enumerated + try: + ssl_context_args["cert_reqs"] = getattr(ssl, ssl_context_args["cert_reqs"]) + except KeyError: + pass + + try: + ssl_context_args["tls_version"] = getattr(ssl, ssl_context_args["tls_version"]) + except AttributeError as e: + raise exceptions.MQTTTLSError( + "Error getting TLS version from " + "ssl module - ssl module had no attribute '{}'. Check the " + "documentation for the version of python you're using to see " + "if this a valid option.".format(ssl_context_args["tls_version"]) + ) from e + except KeyError: + pass + + return ssl_context_args + + class MQTTClient: # pylint: disable=too-many-instance-attributes @@ -116,6 +155,15 @@ def __init__(self, **kwargs): "ciphers", }, "auth": {"username", "password"}, + "ssl_context": { + "ca_certs", + "certfile", + "keyfile", + "password", + "tls_version", + "ciphers", + "alpn_protocols", + }, } logger.debug("Initialising MQTT client with %s", kwargs) @@ -139,12 +187,24 @@ def __init__(self, **kwargs): self._connect_timeout = self._connect_args.pop("timeout", 3) - # If there is any tls kwarg (including 'enable'), enable tls + # If there is any tls or ssl_context kwarg, configure tls encryption file_tls_args = kwargs.pop("tls", {}) + file_ssl_context_args = kwargs.pop("ssl_context", {}) + + if file_tls_args and file_ssl_context_args: + msg = ( + "'tls' and 'ssl_context' are both specified but are mutually exclusive" + ) + raise exceptions.MQTTTLSError(msg) + check_expected_keys(expected_blocks["tls"], file_tls_args) self._tls_args = _handle_tls_args(file_tls_args) logger.debug("TLS is %s", "enabled" if self._tls_args else "disabled") + # If there is any SSL kwarg, enable tls through the SSL context + check_expected_keys(expected_blocks["ssl_context"], file_ssl_context_args) + self._ssl_context_args = _handle_ssl_context_args(file_ssl_context_args) + logger.debug("Paho client args: %s", self._client_args) self._client = paho.Client(**self._client_args) self._client.enable_logger() @@ -171,6 +231,57 @@ def __init__(self, **kwargs): "Unexpected SSL error enabling TLS" ) from e + if self._ssl_context_args: + # Create SSLContext object + tls_version = self._ssl_context_args.get("tls_version") + if tls_version is None: + # If the python version supports it, use highest TLS version automatically + if hasattr(ssl, "PROTOCOL_TLS_CLIENT"): + tls_version = ssl.PROTOCOL_TLS_CLIENT + elif hasattr(ssl, "PROTOCOL_TLS"): + tls_version = ssl.PROTOCOL_TLS + else: + tls_version = ssl.PROTOCOL_TLSv1_2 + ca_certs = self._ssl_context_args.get("cert_reqs") + context = ssl.create_default_context(cafile=ca_certs) + + certfile = self._ssl_context_args.get("certfile") + keyfile = self._ssl_context_args.get("keyfile") + password = self._ssl_context_args.get("password") + + # Configure context + if certfile is not None: + context.load_cert_chain(certfile, keyfile, password) + + cert_reqs = self._ssl_context_args.get("cert_reqs") + if cert_reqs == ssl.CERT_NONE and hasattr(context, "check_hostname"): + context.check_hostname = False + + context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs + + if ca_certs is not None: + context.load_verify_locations(ca_certs) + else: + context.load_default_certs() + + ciphers = self._ssl_context_args.get("cipthers") + if ciphers is not None: + context.set_ciphers(ciphers) + + alpn_protocols = self._ssl_context_args.get("alpn_protocols") + if alpn_protocols is not None: + context.set_alpn_protocols(alpn_protocols) + + self._client.tls_set_context(context) + + if cert_reqs != ssl.CERT_NONE: + # Default to secure, sets context.check_hostname attribute + # if available + self._client.tls_insecure_set(False) + else: + # But with ssl.CERT_NONE, we can not check_hostname + self._client.tls_insecure_set(True) + # Arbitrary number, could just be 1 and only accept 1 message per stages # but we might want to raise an error if more than 1 message is received # during a test stage. diff --git a/tavern/_plugins/mqtt/jsonschema.yaml b/tavern/_plugins/mqtt/jsonschema.yaml index 733bad3b..6b3af52c 100644 --- a/tavern/_plugins/mqtt/jsonschema.yaml +++ b/tavern/_plugins/mqtt/jsonschema.yaml @@ -68,7 +68,7 @@ properties: description: How long to wait for connection before giving up tls: - description: Custom options to control secure connection + description: Basic custom options to control secure connection type: object additionalProperties: false @@ -107,6 +107,51 @@ properties: type: string description: Allowed ciphers to use with connection + ssl_context: + description: Advanced custom options to control secure connection using SSLContext + + type: object + additionalProperties: false + + properties: + ca_certs: + type: string + description: Path to CA cert bundle + + certfile: + type: string + description: Path to certificate for server + + keyfile: + type: string + description: Path to private key for client + + password: + type: string + description: Password for keyfile + + cert_reqs: + type: string + description: Controls connection with cert + enum: + - CERT_NONE + - CERT_OPTIONAL + - CERT_REQUIRED + + tls_version: + type: string + description: TLS version to use + + ciphers: + type: string + description: Allowed ciphers to use with connection + + alpn_protocols: + type: array + description: | + Which protocols the socket should advertise during the SSL/TLS handshake. + See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_alpn_protocols + auth: description: Username and password for basic authorisation diff --git a/tavern/_plugins/mqtt/schema.yaml b/tavern/_plugins/mqtt/schema.yaml index d4bb8ef5..88f52a67 100644 --- a/tavern/_plugins/mqtt/schema.yaml +++ b/tavern/_plugins/mqtt/schema.yaml @@ -82,6 +82,51 @@ initialisation: required: false type: str + ssl_context: + required: false + type: map + mapping: + ca_certs: + required: false + type: str + + certfile: + required: false + type: str + + keyfile: + required: false + type: str + + password: + required: false + type: str + # This is the password for the keyfile, and is only needed if the keyfile is password encrypted + # If not supplied, but the keyfile is password protect, the ssl module will prompt for a password in terminal + + cert_reqs: + required: false + type: str + enum: + - CERT_NONE + - CERT_OPTIONAL + - CERT_REQUIRED + + tls_version: + required: false + type: str + # This could be an enum but there's lots of them, and which ones are + # actually valid changes based on which version of python you're + # using. Just let any ssl errors propagate through + + ciphers: + required: false + type: str + + alpn_protocols: + required: false + type: array + auth: required: false type: map