Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix incompatibility with Twisted < 21. #10713

Merged
merged 3 commits into from
Aug 27, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10713.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ files =
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_itertools.py,
Expand Down
37 changes: 31 additions & 6 deletions synapse/handlers/send_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.

import email.utils
import inspect
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from io import BytesIO
from typing import TYPE_CHECKING, Optional

from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTCP
from twisted.mail.smtp import ESMTPSenderFactory
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory

from synapse.logging.context import make_deferred_yieldable

Expand All @@ -31,6 +32,21 @@
logger = logging.getLogger(__name__)


class _NoTLSESMTPSender(ESMTPSender):
"""Extend ESMTPSender to disable TLS

Unfortunatlely ESMTPSender doesn't give an easy way to disable TLS, so we override
its internal method which it uses to generate a context factory.

As of Twisted 21.2, one alternative is to set the `hostname` param of
ESMTPSenderFactory to `None`, so if in future we drop support for earlier versions,
that is a possibility.
"""

def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]:
return None


async def _sendmail(
reactor: IReactorTCP,
smtphost: str,
Expand All @@ -42,7 +58,7 @@ async def _sendmail(
password: Optional[bytes] = None,
require_auth: bool = False,
require_tls: bool = False,
tls_hostname: Optional[str] = None,
enable_tls: bool = True,
) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests

Expand All @@ -57,12 +73,19 @@ async def _sendmail(
password: password to give when authenticating
require_auth: if auth is not offered, fail the request
require_tls: if TLS is not offered, fail the reqest
tls_hostname: TLS hostname to check for. None to disable TLS.
enable_tls: True to enable TLS. If this is False and require_tls is True,
the request will fail.
"""
msg = BytesIO(msg_bytes)

d: "Deferred[object]" = Deferred()

# Twisted 21.2 introduced a 'hostname' parameter to ESMTPSenderFactory, which we
# need to set to enable TLS.
kwargs = {}
sig = inspect.signature(ESMTPSenderFactory)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead gate these changes behind a twisted.__version__ check? That way we don't need to worry about Twisted changing their internals in future versions, and eventually we can easily just drop the version check entirely when we bump our minimum dependencies?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, ptal?

if "hostname" in sig.parameters:
kwargs["hostname"] = smtphost
factory = ESMTPSenderFactory(
username,
password,
Expand All @@ -73,8 +96,10 @@ async def _sendmail(
heloFallback=True,
requireAuthentication=require_auth,
requireTransportSecurity=require_tls,
hostname=tls_hostname,
**kwargs,
)
if not enable_tls:
factory.protocol = _NoTLSESMTPSender

# the IReactorTCP interface claims host has to be a bytes, which seems to be wrong
reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type]
Expand Down Expand Up @@ -154,5 +179,5 @@ async def send_email(
password=self._smtp_pass,
require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security,
tls_hostname=self._smtp_host if self._enable_tls else None,
enable_tls=self._enable_tls,
)
112 changes: 112 additions & 0 deletions tests/handlers/test_send_email.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Tuple

from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.address import IPv4Address
from twisted.internet.defer import ensureDeferred
from twisted.mail import interfaces, smtp

from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase


@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self):
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []

def receivedHeader(self, helo, origin, recipients):
return None

def validateFrom(self, helo, origin):
return origin

def record_message(self, recipient: smtp.Address, message: bytes):
self.messages.append((recipient, message))

def validateTo(self, user: smtp.User):
return lambda: _DummyMessage(self, user)


@implementer(interfaces.IMessageSMTP)
class _DummyMessage:
"""IMessageSMTP implementation which saves the message delivered to it
to the _DummyMessageDelivery object.
"""

def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
self._delivery = delivery
self._user = user
self._buffer: List[bytes] = []

def lineReceived(self, line):
self._buffer.append(line)

def eomReceived(self):
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")

def connectionLost(self):
pass


class SendEmailHandlerTestCase(HomeserverTestCase):
def test_send_email(self):
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
h.send_email(
"[email protected]", "test subject", "Tests", "HTML content", "Text content"
)
)
# there should be an attempt to connect to localhost:25
self.assertEqual(len(self.reactor.tcpClients), 1)
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
self.assertEqual(host, "localhost")
self.assertEqual(port, 25)

# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
server_protocol = smtp.ESMTP()
server_protocol.delivery = message_delivery
# make sure that the server uses the test reactor to set timeouts
server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]

client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
server_protocol.makeConnection(
FakeTransport(
client_protocol,
self.reactor,
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
)
)

# the message should now get delivered
self.get_success(d, by=0.1)

# check it arrived
self.assertEqual(len(message_delivery.messages), 1)
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "[email protected]")
self.assertIn(b"Subject: test subject", msg)
15 changes: 12 additions & 3 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
IHostnameResolver,
IProtocol,
IPullProducer,
Expand Down Expand Up @@ -511,6 +512,9 @@ class FakeTransport:
will get called back for connectionLost() notifications etc.
"""

_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""

disconnecting = False
disconnected = False
connected = True
Expand All @@ -519,7 +523,7 @@ class FakeTransport:
autoflush = attr.ib(default=True)

def getPeer(self):
return None
return self._peer_address

def getHost(self):
return None
Expand Down Expand Up @@ -572,7 +576,12 @@ def registerProducer(self, producer, streaming):
self.producerStreaming = streaming

def _produce():
d = self.producer.resumeProducing()
if not self.producer:
# we've been unregistered
return
# some implementations of IProducer (for example, FileSender)
# don't return a deferred.
d = maybeDeferred(self.producer.resumeProducing)
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

if not streaming:
Expand Down