Skip to content

Commit

Permalink
Initial commit for explicit TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
sammichaels committed Nov 10, 2023
1 parent 80b8ebb commit 86a6a8c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
59 changes: 58 additions & 1 deletion aioftp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import pathlib
import re
import ssl
import sys
from functools import partial

from . import errors, pathio
Expand Down Expand Up @@ -38,6 +40,7 @@
"DataConnectionThrottleStreamIO",
"Code",
)
IS_PY311_PLUS = sys.version_info >= (3, 11, 0)
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -118,6 +121,7 @@ def __init__(self, *,
parse_list_line_custom=None,
parse_list_line_custom_first=True,
passive_commands=("epsv", "pasv"),
explicit_tls=False,
**siosocks_asyncio_kwargs):
self.socket_timeout = socket_timeout
self.connection_timeout = connection_timeout
Expand All @@ -133,8 +137,10 @@ def __init__(self, *,
self.parse_list_line_custom = parse_list_line_custom
self.parse_list_line_custom_first = parse_list_line_custom_first
self._passive_commands = passive_commands
self._open_connection = partial(open_connection, ssl=self.ssl,
self._open_connection = partial(open_connection, ssl=ssl if not explicit_tls else None,
**siosocks_asyncio_kwargs)
self.explicit_tls = explicit_tls
self.logged_in = False

async def connect(self, host, port=DEFAULT_PORT):
self.server_host = host
Expand Down Expand Up @@ -628,6 +634,47 @@ async def connect(self, host, port=DEFAULT_PORT):
code, info = await self.command(None, "220", "120")
return info

async def _send_tls_protection_commands(self) -> None:
"""
:py:func:`asyncio.coroutine`
Sends the PBSZ and PROT commands as required for TLS connection.
"""
await self.command("PBSZ 0", "200")
await self.command("PROT P", "200")

async def upgrade_to_tls(self) -> None:
"""
:py:func:`asyncio.coroutine`
Attempts to upgrade the connection to TLS (explicit TLS).
Downgrading via the CCC or REIN commands is not supported. You may
call this command at any point during the connection. Both the command
and data channels will be encrypted after using this command.
asyncio.StreamWriter.start_tls() was added in 3.11. Using this function
with an unsupported Python version will raise a RuntimeError.
"""
if not IS_PY311_PLUS:
raise RuntimeError("Python version 3.11.0 is required to upgrade a connection to TLS")

if self.stream.writer.get_extra_info("ssl_object"):
return

self.explicit_tls = True

await self.command("AUTH TLS", "234")

if not isinstance(self.ssl, ssl.SSLContext):
self.ssl = ssl.create_default_context()
try:
await self.stream.start_tls(sslcontext=self.ssl, server_hostname=self.server_host)
except ssl.SSLError as e:
raise errors.TLSError("Unable to upgrade connection to TLS") from e

if self.logged_in:
await self._send_tls_protection_commands()

async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD,
account=DEFAULT_ACCOUNT):
"""
Expand All @@ -646,6 +693,9 @@ async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD,
:raises aioftp.StatusCodeError: if unknown code received
"""
if self.explicit_tls:
await self.upgrade_to_tls()

code, info = await self.command("USER " + user, ("230", "33x"))
while code.matches("33x"):
censor_after = None
Expand All @@ -658,6 +708,8 @@ async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD,
raise errors.StatusCodeError("33x", code, info)
code, info = await self.command(cmd, ("230", "33x"),
censor_after=censor_after)
self.logged_in = True
await self._send_tls_protection_commands()

async def get_current_directory(self):
"""
Expand Down Expand Up @@ -1167,6 +1219,11 @@ async def get_stream(self, *command_args, conn_type="I", offset=0):
throttles={"_": self.throttle},
timeout=self.socket_timeout,
)
if self.explicit_tls:
try:
await stream.start_tls(sslcontext=self.ssl, server_hostname=self.server_host)
except ssl.SSLError as e:
raise errors.TLSError("Unable to upgrade data connection to TLS") from e
return stream

async def abort(self, *, wait=True):
Expand Down
13 changes: 13 additions & 0 deletions aioftp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import functools
import locale
import threading
import ssl
import sys
from contextlib import contextmanager

__all__ = (
Expand All @@ -25,6 +27,7 @@
"DEFAULT_ACCOUNT",
"setlocale",
)
IS_PY311_PLUS = sys.version_info >= (3, 11, 0)


END_OF_LINE = "\r\n"
Expand Down Expand Up @@ -317,6 +320,16 @@ def close(self):
"""
self.writer.close()

async def start_tls(self, sslcontext: ssl.SSLContext, server_hostname: str) -> None:
"""
Upgrades the connection to TLS
"""
if not IS_PY311_PLUS:
raise RuntimeError("Python version 3.11.0 is required to upgrade a connection to TLS")
await self.writer.start_tls(sslcontext=sslcontext,
server_hostname=server_hostname,
ssl_handshake_timeout=self.write_timeout)


class Throttle:
"""
Expand Down
7 changes: 7 additions & 0 deletions aioftp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"PathIsNotAbsolute",
"PathIOError",
"NoAvailablePort",
"TLSError",
)


Expand Down Expand Up @@ -79,3 +80,9 @@ class NoAvailablePort(AIOFTPException, OSError):
"""
Raised when there is no available data port
"""


class TLSError(AIOFTPException):
"""
Any TLS related errors
"""

0 comments on commit 86a6a8c

Please sign in to comment.