Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Faster joins: persist to database #12012

Merged
merged 8 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12012.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database.
9 changes: 9 additions & 0 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class EventContext:

As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids.

partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""

rejected: Union[bool, str] = False
Expand All @@ -113,12 +116,15 @@ class EventContext:
_current_state_ids: Optional[StateMap[str]] = None
_prev_state_ids: Optional[StateMap[str]] = None

partial_state: bool = False

@staticmethod
def with_state(
state_group: Optional[int],
state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
Expand All @@ -129,6 +135,7 @@ def with_state(
state_group_before_event=state_group_before_event,
prev_group=prev_group,
delta_ids=delta_ids,
partial_state=partial_state,
)

@staticmethod
Expand Down Expand Up @@ -170,6 +177,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
}

@staticmethod
Expand All @@ -196,6 +204,7 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
)

app_service_id = input["app_service_id"]
Expand Down
11 changes: 10 additions & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,17 @@ async def do_invite_join(
state_events=state,
)

if ret.partial_state:
await self.store.store_partial_state_room(room_id, ret.servers_in_room)

max_stream_id = await self._federation_event_handler.process_remote_join(
origin, room_id, auth_chain, state, event, room_version_obj
origin,
room_id,
auth_chain,
state,
event,
room_version_obj,
partial_state=ret.partial_state,
)

# We wait here until this instance has seen the events come down
Expand Down
13 changes: 11 additions & 2 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ async def process_remote_join(
state: List[EventBase],
event: EventBase,
room_version: RoomVersion,
partial_state: bool,
) -> int:
"""Persists the events returned by a send_join

Expand All @@ -412,6 +413,7 @@ async def process_remote_join(
event
room_version: The room version we expect this room to have, and
will raise if it doesn't match the version in the create event.
partial_state: True if the state omits non-critical membership events

Returns:
The stream ID after which all events have been persisted.
Expand Down Expand Up @@ -453,10 +455,14 @@ async def process_remote_join(
)

# and now persist the join event itself.
logger.info("Peristing join-via-remote %s", event)
logger.info(
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event, old_state=state
event,
old_state=state,
partial_state=partial_state,
)

context = await self._check_event_auth(origin, event, context)
Expand Down Expand Up @@ -698,6 +704,8 @@ async def _process_pulled_event(

try:
state = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled
)
Expand Down Expand Up @@ -1791,6 +1799,7 @@ async def _update_context_for_auth_events(
prev_state_ids=prev_state_ids,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
)

async def _run_push_actions_and_persist_event(
Expand Down
2 changes: 2 additions & 0 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,8 @@ async def create_new_client_event(
and full_state_ids_at_event
and builder.internal_metadata.is_historical()
):
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(full_state_ids_at_event)
context = await self.state.compute_event_context(event, old_state=old_state)
else:
Expand Down
31 changes: 29 additions & 2 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,10 @@ async def get_hosts_in_room_at_events(
return await self.store.get_joined_hosts(room_id, entry)

async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.

Expand All @@ -273,6 +276,8 @@ async def compute_event_context(
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
Returns:
The event context.
"""
Expand All @@ -295,8 +300,28 @@ async def compute_event_context(

else:
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")

# partial_state should not be set explicitly in this case:
# we work it out dynamically
assert not partial_state
squahtx marked this conversation as resolved.
Show resolved Hide resolved

# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
# their states - but I don't think that really matters; it just means we
# might redundantly recalculate the state for this event later.)
prev_event_ids = event.prev_event_ids()
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
if any(incomplete_prev_events.values()):
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
partial_state = True

logger.debug("calling resolve_state_groups from compute_event_context")
entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
Expand Down Expand Up @@ -342,6 +367,7 @@ async def compute_event_context(
prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=partial_state,
)

#
Expand Down Expand Up @@ -373,6 +399,7 @@ async def compute_event_context(
prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event,
delta_ids=delta_ids,
partial_state=partial_state,
)

@measure_func()
Expand Down
25 changes: 25 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,14 @@ def _store_event_state_mappings_txn(
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
# double-check that we don't have any events that claim to be outliers
# *and* have partial state (which is meaningless: we should have no
# state at all for an outlier)
if context.partial_state:
raise ValueError(
"Outlier event %s claims to have partial state", event.event_id
)

continue

# if the event was rejected, just give it the same state as its
Expand All @@ -2152,6 +2160,23 @@ def _store_event_state_mappings_txn(

state_groups[event.event_id] = context.state_group

# if we have partial state for these events, record the fact. (This happens
# here rather than in _store_event_txn because it also needs to happen when
# we de-outlier an event.)
Comment on lines +2163 to +2165
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unclear on why we need to do this again when de-outliering. Does de-outliering not imply that the event has already been through this code once?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to be another great question. I'll try and remember what's going on here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I've remembered what this is about.

Outliers have no state at all, so it's meaningless to say that they have "partial state". So, the first time they go through this code, we would not expect to set the partial-state flag.

When we de-outlier the event, it's possible that we might de-outlier it with only partial state, and in that case we need to add a row to partial_state_events when we de-outlier it, which is why I've put this code here.

I've added a robustness check to make sure it's true that outliers are never flagged with partial-state in 446fae6.

I've also got it on my todo list to make sure we end up with tests which include de-outliering events with partial state.

Copy link
Contributor

@squahtx squahtx Feb 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! So the outliers we persist will never be flagged with partial state because we always use EventContext.for_outlier() as the context*?

* with the exception of the outlier in RoomMemberMasterHandler._generate_local_out_of_band_leave for reasons that are beyond me right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • with the exception of the outlier in RoomMemberMasterHandler._generate_local_out_of_band_leave for reasons that are beyond me right now.

Yeah, that looks a bit bogus. It means we end up recording state at the leave event, but also flagging it as an outlier, which seems wrong. Either it shouldn't be an outlier, or we shouldn't record the state at it. I'm going to park this for now though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the record: I've ended up with a fix to this, at #12154, as it broke in the face of my fixes to #12074.

self.db_pool.simple_insert_many_txn(
txn,
table="partial_state_events",
keys=("room_id", "event_id"),
values=[
(
event.room_id,
event.event_id,
)
for event, ctx in events_and_contexts
if ctx.partial_state
],
)

self.db_pool.simple_upsert_many_txn(
txn,
table="event_to_state_groups",
Expand Down
28 changes: 28 additions & 0 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,3 +1953,31 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
)

@cachedList("is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
    self, event_ids: Collection[str]
) -> Dict[str, bool]:
    """Report, for each of the given events, whether it has partial state.

    Args:
        event_ids: the IDs of the events to check

    Returns:
        A dict with one entry per requested event ID: True if a row for that
        event was found in `partial_state_events`, False otherwise.
    """
    rows = await self.db_pool.simple_select_many_batch(
        table="partial_state_events",
        column="event_id",
        iterable=event_ids,
        retcols=["event_id"],
        desc="get_partial_state_events",
    )

    # collect the IDs which actually have a row in the table
    ids_with_partial_state = set()
    for row in rows:
        ids_with_partial_state.add(row["event_id"])

    # the answer is returned as a dict keyed by event ID, to make
    # @cachedList work
    flags: Dict[str, bool] = {}
    for requested_id in event_ids:
        flags[requested_id] = requested_id in ids_with_partial_state
    return flags

@cached()
async def is_partial_state_event(self, event_id: str) -> bool:
    """Report whether the given event has partial state.

    Args:
        event_id: the ID of the event to check

    Returns:
        True if the lookup in `partial_state_events` yields a value for this
        event, False if it yields None.
    """
    found = await self.db_pool.simple_select_one_onecol(
        table="partial_state_events",
        keyvalues={"event_id": event_id},
        retcol="1",
        allow_none=True,
        desc="is_partial_state_event",
    )
    # NOTE(review): presumably `allow_none=True` makes the helper return None
    # when no row matches, rather than raising - confirm against DatabasePool.
    if found is None:
        return False
    return True
37 changes: 37 additions & 0 deletions synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TYPE_CHECKING,
Any,
Awaitable,
Collection,
Dict,
List,
Optional,
Expand Down Expand Up @@ -1543,6 +1544,42 @@ async def upsert_room_on_join(
lock=False,
)

async def store_partial_state_room(
    self,
    room_id: str,
    servers: Collection[str],
) -> None:
    """Record that the given room contains events with partial state.

    The actual writes are done by `_store_partial_state_room_txn`, which is
    run as a single database interaction.

    Args:
        room_id: the ID of the room
        servers: other servers known to be in the room
    """
    txn_args = (room_id, servers)
    await self.db_pool.runInteraction(
        "store_partial_state_room",
        self._store_partial_state_room_txn,
        *txn_args,
    )

@staticmethod
def _store_partial_state_room_txn(
    txn: LoggingTransaction, room_id: str, servers: Collection[str]
) -> None:
    """Transaction body for `store_partial_state_room`.

    Inserts one row for the room into `partial_state_rooms`, and then one row
    per server into `partial_state_rooms_servers`.

    Args:
        txn: the database transaction to write with
        room_id: the ID of the room
        servers: other servers known to be in the room
    """
    # the room row goes in first: the rows in partial_state_rooms_servers
    # refer back to it.
    room_row = {"room_id": room_id}
    DatabasePool.simple_insert_txn(
        txn,
        table="partial_state_rooms",
        values=room_row,
    )

    # one (room_id, server_name) pair per known server, generated lazily
    server_rows = ((room_id, server_name) for server_name in servers)
    DatabasePool.simple_insert_many_txn(
        txn,
        table="partial_state_rooms_servers",
        keys=("room_id", "server_name"),
        values=server_rows,
    )

async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
) -> None:
Expand Down
41 changes: 41 additions & 0 deletions synapse/storage/schema/main/delta/68/04partial_state_rooms.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- rooms which we have done a partial-state-style join to.
-- There is at most one row per room (room_id is the primary key), and the room
-- must already have a row in `rooms`.
CREATE TABLE IF NOT EXISTS partial_state_rooms (
room_id TEXT PRIMARY KEY,
FOREIGN KEY(room_id) REFERENCES rooms(room_id)
);

-- a list of remote servers we believe are in the room.
-- Each row refers to a room in `partial_state_rooms`; a given server can be
-- listed at most once per room.
CREATE TABLE IF NOT EXISTS partial_state_rooms_servers (
room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
server_name TEXT NOT NULL,
UNIQUE(room_id, server_name)
);

-- a list of events with partial state. We can't store this in the `events` table
-- itself, because `events` is meant to be append-only.
-- Each event can be listed at most once, and must belong to a room which is
-- listed in `partial_state_rooms`.
CREATE TABLE IF NOT EXISTS partial_state_events (
-- the room_id is denormalised for efficient indexing (the canonical source is `events`)
room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
event_id TEXT NOT NULL REFERENCES events(event_id),
UNIQUE(event_id)
);

-- index on the denormalised room_id column of partial_state_events
CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx
ON partial_state_events (room_id);


Loading