diff --git a/aioftp/client.py b/aioftp/client.py index 0c3ec85..c23ea6e 100644 --- a/aioftp/client.py +++ b/aioftp/client.py @@ -6,6 +6,8 @@ import logging import pathlib import re +import ssl +import sys from functools import partial from . import errors, pathio @@ -38,6 +40,7 @@ "DataConnectionThrottleStreamIO", "Code", ) +IS_PY311_PLUS = sys.version_info >= (3, 11, 0) logger = logging.getLogger(__name__) @@ -118,6 +121,7 @@ def __init__(self, *, parse_list_line_custom=None, parse_list_line_custom_first=True, passive_commands=("epsv", "pasv"), + explicit_tls=False, **siosocks_asyncio_kwargs): self.socket_timeout = socket_timeout self.connection_timeout = connection_timeout @@ -133,8 +137,10 @@ def __init__(self, *, self.parse_list_line_custom = parse_list_line_custom self.parse_list_line_custom_first = parse_list_line_custom_first self._passive_commands = passive_commands - self._open_connection = partial(open_connection, ssl=self.ssl, + self._open_connection = partial(open_connection, ssl=ssl if not explicit_tls else None, **siosocks_asyncio_kwargs) + self.explicit_tls = explicit_tls + self.logged_in = False async def connect(self, host, port=DEFAULT_PORT): self.server_host = host @@ -628,6 +634,47 @@ async def connect(self, host, port=DEFAULT_PORT): code, info = await self.command(None, "220", "120") return info + async def _send_tls_protection_commands(self) -> None: + """ + :py:func:`asyncio.coroutine` + + Sends the PBSZ and PROT commands as required for TLS connection. + """ + await self.command("PBSZ 0", "200") + await self.command("PROT P", "200") + + async def upgrade_to_tls(self) -> None: + """ + :py:func:`asyncio.coroutine` + + Attempts to upgrade the connection to TLS (explicit TLS). + Downgrading via the CCC or REIN commands is not supported. You may + call this command at any point during the connection. Both the command + and data channels will be encrypted after using this command. + + asyncio.StreamWriter.start_tls() was added in 3.11. Using this function + with an unsupported Python version will raise a RuntimeError. + """ + if not IS_PY311_PLUS: + raise RuntimeError("Python version 3.11.0 is required to upgrade a connection to TLS") + + if self.stream.writer.get_extra_info("ssl_object"): + return + + self.explicit_tls = True + + await self.command("AUTH TLS", "234") + + if not isinstance(self.ssl, ssl.SSLContext): + self.ssl = ssl.create_default_context() + try: + await self.stream.start_tls(sslcontext=self.ssl, server_hostname=self.server_host) + except ssl.SSLError as e: + raise errors.TLSError("Unable to upgrade connection to TLS") from e + + if self.logged_in: + await self._send_tls_protection_commands() + async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD, account=DEFAULT_ACCOUNT): """ @@ -646,6 +693,9 @@ async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD, :raises aioftp.StatusCodeError: if unknown code received """ + if self.explicit_tls: + await self.upgrade_to_tls() + code, info = await self.command("USER " + user, ("230", "33x")) while code.matches("33x"): censor_after = None @@ -658,6 +708,8 @@ async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD, raise errors.StatusCodeError("33x", code, info) code, info = await self.command(cmd, ("230", "33x"), censor_after=censor_after) + self.logged_in = True + await self._send_tls_protection_commands() async def get_current_directory(self): """ @@ -1167,6 +1219,11 @@ async def get_stream(self, *command_args, conn_type="I", offset=0): throttles={"_": self.throttle}, timeout=self.socket_timeout, ) + if self.explicit_tls: + try: + await stream.start_tls(sslcontext=self.ssl, server_hostname=self.server_host) + except ssl.SSLError as e: + raise errors.TLSError("Unable to upgrade data connection to TLS") from e return stream async def abort(self, *, wait=True): diff --git a/aioftp/common.py b/aioftp/common.py index 024d5db..69b6666 100644 --- a/aioftp/common.py +++ b/aioftp/common.py @@ -4,6 +4,8 @@ import functools import locale import threading +import ssl +import sys from contextlib import contextmanager __all__ = ( @@ -25,6 +27,7 @@ "DEFAULT_ACCOUNT", "setlocale", ) +IS_PY311_PLUS = sys.version_info >= (3, 11, 0) END_OF_LINE = "\r\n" @@ -317,6 +320,16 @@ def close(self): """ self.writer.close() + async def start_tls(self, sslcontext: ssl.SSLContext, server_hostname: str) -> None: + """ + Upgrades the connection to TLS + """ + if not IS_PY311_PLUS: + raise RuntimeError("Python version 3.11.0 is required to upgrade a connection to TLS") + await self.writer.start_tls(sslcontext=sslcontext, + server_hostname=server_hostname, + ssl_handshake_timeout=self.write_timeout) + class Throttle: """ diff --git a/aioftp/errors.py b/aioftp/errors.py index 2955cc4..c4ccd6d 100644 --- a/aioftp/errors.py +++ b/aioftp/errors.py @@ -7,6 +7,7 @@ "PathIsNotAbsolute", "PathIOError", "NoAvailablePort", + "TLSError", ) @@ -79,3 +80,9 @@ class NoAvailablePort(AIOFTPException, OSError): """ Raised when there is no available data port """ + + +class TLSError(AIOFTPException): + """ + Any TLS related errors + """