Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Faster joins: parse msc3706 fields in send_join response #12011

Merged
merged 2 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12011.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: parse msc3706 fields in send_join response.
4 changes: 4 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ def read_config(self, config: JsonDict, **kwargs):

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

# Experimental support for faster joins over federation (MSC2775, MSC3706).
# Requires the target server to have `msc3706_enabled` set.
self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
15 changes: 14 additions & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -89,6 +89,12 @@ class SendJoinResult:
state: List[EventBase]
auth_chain: List[EventBase]

# True if 'state' elides non-critical membership events
partial_state: bool

# if 'partial_state' is set, a list of the servers in the room (otherwise empty)
servers_in_room: List[str]


class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
Expand Down Expand Up @@ -876,11 +882,18 @@ async def _execute(pdu: EventBase) -> None:
% (auth_chain_create_events,)
)

if response.partial_state and not response.servers_in_room:
raise InvalidResponseError(
"partial_state was set, but no servers were listed in the room"
)

return SendJoinResult(
event=event,
state=signed_state,
auth_chain=signed_auth,
origin=destination,
partial_state=response.partial_state,
servers_in_room=response.servers_in_room or [],
)

# MSC3083 defines additional error codes for room joins.
Expand Down
118 changes: 87 additions & 31 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -60,6 +60,7 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled

async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
Expand Down Expand Up @@ -336,10 +337,15 @@ async def send_join_v2(
content: JsonDict,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
query_params: Dict[str, str] = {}
if self._faster_joins_enabled:
# lazy-load state on join
query_params["org.matrix.msc3706.partial_state"] = "true"

return await self.client.put_json(
destination=destination,
path=path,
args=query_params,
data=content,
parser=SendJoinParser(room_version, v1_api=False),
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
Expand Down Expand Up @@ -1271,6 +1277,12 @@ class SendJoinResponse:
# "event" is not included in the response.
event: Optional[EventBase] = None

# The room state is incomplete
partial_state: bool = False

# List of servers in the room
servers_in_room: Optional[List[str]] = None


@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
Expand All @@ -1297,6 +1309,32 @@ def _event_list_parser(
events.append(event)


@ijson.coroutine
def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
    """Coroutine target for `ijson.items_coro` handling `partial_state`.

    Every value sent in is checked to be a boolean and then stored on
    `response.partial_state`.

    Raises:
        TypeError: if a received value is not a bool.
    """
    while True:
        flag = yield
        if isinstance(flag, bool):
            response.partial_state = flag
            continue
        raise TypeError("partial_state must be a boolean")


@ijson.coroutine
def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
    """Coroutine target for `ijson.items_coro` handling `servers_in_room`.

    Every value sent in must be a list containing only strings; it is then
    stored on `response.servers_in_room`.

    Raises:
        TypeError: if a received value is not a list of strings.
    """
    while True:
        servers = yield
        is_list_of_str = isinstance(servers, list) and all(
            isinstance(name, str) for name in servers
        )
        if not is_list_of_str:
            raise TypeError("servers_in_room must be a list of strings")
        response.servers_in_room = servers


class SendJoinParser(ByteParser[SendJoinResponse]):
"""A parser for the response to `/send_join` requests.

Expand All @@ -1308,44 +1346,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"

def __init__(self, room_version: RoomVersion, v1_api: bool):
    """Set up the streaming sub-parsers that populate the response.

    Args:
        room_version: room version handed to the event-list parsers.
        v1_api: True if the response has the v1 shape `[200, {...}]`.
    """
    self._response = SendJoinResponse([], [], event_dict={})
    self._room_version = room_version

    # The V1 API has the shape of `[200, {...}]`, which we handle by
    # prefixing with `item.*`.
    prefix = "item." if v1_api else ""

    # One coroutine per field of interest. Each must appear exactly once:
    # `write` feeds every chunk to every entry, and the list parsers append
    # to the shared response, so a duplicate would duplicate events.
    self._coros = [
        ijson.items_coro(
            _event_list_parser(room_version, self._response.state),
            prefix + "state.item",
            use_float=True,
        ),
        ijson.items_coro(
            _event_list_parser(room_version, self._response.auth_events),
            prefix + "auth_chain.item",
            use_float=True,
        ),
        # TODO Remove the unstable prefix when servers have updated.
        #
        # By re-using the same event dictionary this will cause the parsing of
        # org.matrix.msc3083.v2.event and event to stomp over each other.
        # Generally this should be fine.
        ijson.kvitems_coro(
            _event_parser(self._response.event_dict),
            prefix + "org.matrix.msc3083.v2.event",
            use_float=True,
        ),
        ijson.kvitems_coro(
            _event_parser(self._response.event_dict),
            prefix + "event",
            use_float=True,
        ),
    ]

    # The MSC3706 fields are only parsed for the v2 API.
    if not v1_api:
        self._coros.append(
            ijson.items_coro(
                _partial_state_parser(self._response),
                "org.matrix.msc3706.partial_state",
                # a real boolean, matching the calls above (was the string "True")
                use_float=True,
            )
        )

        self._coros.append(
            ijson.items_coro(
                _servers_in_room_parser(self._response),
                "org.matrix.msc3706.servers_in_room",
                use_float=True,
            )
        )

def write(self, data: bytes) -> int:
    """Feed a chunk of the response body to every sub-parser.

    Args:
        data: the next chunk of raw response bytes.

    Returns:
        The number of bytes consumed, which is always `len(data)`.
    """
    for c in self._coros:
        c.send(data)

    return len(data)

Expand Down
3 changes: 2 additions & 1 deletion synapse/python_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
"ijson>=3.1",
# ijson 3.1.4 fixes a bug with "." in property names
"ijson>=3.1.4",
"matrix-common~=1.1.0",
]

Expand Down
32 changes: 32 additions & 0 deletions tests/federation/transport/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,35 @@ def test_two_writes(self) -> None:
self.assertEqual(len(parsed_response.state), 1, parsed_response)
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
self.assertIsNone(parsed_response.event, parsed_response)
self.assertFalse(parsed_response.partial_state, parsed_response)
self.assertEqual(parsed_response.servers_in_room, None, parsed_response)

def test_partial_state(self) -> None:
    """Verify that the parser picks up the partial_state flag"""
    body = json.dumps({"org.matrix.msc3706.partial_state": True}).encode()

    # Hand the whole serialised response to the parser in a single write
    send_join_parser = SendJoinParser(RoomVersions.V1, False)
    send_join_parser.write(body)

    # The finished SendJoinResponse should carry the flag
    result = send_join_parser.finish()
    self.assertTrue(result.partial_state)

def test_servers_in_room(self) -> None:
    """Verify that the parser picks up the servers_in_room list"""
    body = json.dumps(
        {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
    ).encode()

    # Hand the whole serialised response to the parser in a single write
    send_join_parser = SendJoinParser(RoomVersions.V1, False)
    send_join_parser.write(body)

    # The finished SendJoinResponse should list both servers, in order
    result = send_join_parser.finish()
    self.assertEqual(result.servers_in_room, ["hs1", "hs2"])