Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSC3916 by adding _matrix/client/v1/media/download endpoint #17365

Merged
merged 13 commits into from
Jul 2, 2024
Merged
20 changes: 20 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,26 @@ def filter_user_id(user_id: str) -> bool:

return filtered_statuses, filtered_failures

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

async def download_media(
self,
destination: str,
Expand Down
25 changes: 23 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,6 @@ async def download_media_r0(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand Down Expand Up @@ -852,7 +851,6 @@ async def download_media_v3(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand All @@ -873,6 +871,29 @@ async def download_media_v3(
ip_address=ip_address,
)

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)


def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Expand Down
178 changes: 178 additions & 0 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@
BlocklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
SimpleHttpClient,
_make_scheduler,
encode_query_args,
read_body_with_max_size,
read_multipart_response,
)
from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
Expand Down Expand Up @@ -466,6 +468,12 @@ def __init__(

self._sleeper = AwakenableSleeper(self.reactor)

self._simple_http_client = SimpleHttpClient(
hs,
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
use_proxy=True,
)

def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""

Expand Down Expand Up @@ -1553,6 +1561,176 @@ async def get_file(
)
return length, headers

async def federation_get_file(
self,
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: str,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
"""GETs a file from a given homeserver over the federation /download endpoint
Args:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
requester IP
ip_address: IP address of the requester
max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.

Returns:
Resolves to an (int, dict, bytes) tuple of
the file length, a dict of the response headers, and the file json

Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
SynapseError: If the requested file exceeds ratelimits
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)

# check for a minimum balance of 1MiB in ratelimiter before initiating request
send_req, _ = await download_ratelimiter.can_do_action(
requester=None, key=ip_address, n_actions=1048576, update=False
)

if not send_req:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)

response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)

headers = dict(response.headers.getAllRawHeaders())

expected_size = response.length
# if we don't get an expected length then use the max length
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
logger.debug(
f"File size unknown, assuming file is max allowable size: {max_size}"
)

read_body, _ = await download_ratelimiter.can_do_action(
requester=None,
key=ip_address,
n_actions=expected_size,
)
if not read_body:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

# this should be a multipart/mixed response with the boundary string in the header
raw_content_type = headers.get(b"Content-Type")
assert raw_content_type is not None
content_type = raw_content_type[0].decode("UTF-8")
boundary = content_type.split("boundary=")[1]
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved

try:
# add a byte of headroom to max size as function errs at >=
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at read_multipart_response, it looks like it errs at > only?

https://github.com/H-Shay/hq_synapse/blob/5fcc81ade3e9afdab89c0da0aa77bf44304fd8de/synapse/http/client.py#L1237-L1240

Or do you mean buried in _MultipartParserProtocol's parent, protocol.Protocol?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually was referring to what happens in the _MultipartParserProtocol's dataRecieved function: https://github.com/H-Shay/hq_synapse/blob/5fcc81ade3e9afdab89c0da0aa77bf44304fd8de/synapse/http/client.py#L1109

So the check in read_multipart_response checks that the response length provided by the remote server is not greater than the max length, and in the dataRecieved function it checks that the actual data coming in over the wire doesn't exceed that (insurance against faulty/malicious response length data). Erring at >= is modeled after what happens in _ReadBodyWithMaxSizeProtocol:

if self.max_size is not None and self.length >= self.max_size:
, which was the model for protocol I added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I see. Could you just add a pointer to that function to make this comment clearer?

Suggested change
# add a byte of headroom to max size as function errs at >=
# add a byte of headroom to max size as
# `_MultipartParserProtocol.dataReceived` errs at >=

deferred = read_multipart_response(
response, output_stream, boundary, expected_size + 1
)
deferred.addTimeout(self.default_timeout_seconds, self.reactor)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except ResponseFailed as e:
logger.warning(
"{%s} [%s] Failed to read response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
e,
)
raise

multipart_response = await make_deferred_yieldable(deferred)
if not multipart_response.url:
assert multipart_response.length is not None
length = multipart_response.length
headers[b"Content-Type"] = [multipart_response.content_type]
headers[b"Content-Disposition"] = [multipart_response.disposition]

# the response contained a redirect url to download the file from
else:
str_url = multipart_response.url.decode("utf-8")
logger.info(
"{%s} [%s] File download redirected, now downloading from: %s",
request.txn_id,
request.destination,
str_url,
)
length, headers, _, _ = await self._simple_http_client.get_file(
str_url, output_stream, expected_size
)

logger.info(
"{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
length,
request.method,
request.uri.decode("ascii"),
)
return length, headers, multipart_response.json


def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
Expand Down
Loading