diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index a7fcc564a992..34578393467e 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -93,13 +93,6 @@ def _get_state_groups_from_groups_txn( results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in @@ -110,31 +103,71 @@ def _get_state_groups_from_groups_txn( # against `state_groups_state` to fetch the latest state. # It assumes that previous state groups are always numerically # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. sql = """ - WITH RECURSIVE state(state_group) AS ( + WITH RECURSIVE sgs(state_group) AS ( VALUES(?::bigint) UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s + SELECT prev_state_group FROM state_group_edges e, sgs s WHERE s.state_group = e.state_group ) - SELECT DISTINCT ON (type, state_key) - type, state_key, event_id - FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) %s - ORDER BY type, state_key, state_group DESC + %s """ + overall_select_query_args: List[any] = [] + + # This is an optimization to create a select clause per-condition. This + # makes the query planner a lot smarter on what rows should pull out in the + # first place and we end up with something that takes 10x less time to get a + # result. + if not state_filter.include_others and not state_filter.is_full(): + select_clause_list: List[str] = [] + for etype, state_keys in state_filter.types.items(): + for state_key in state_keys: + overall_select_query_args.extend([etype, state_key]) + select_clause_list.append( + """ + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + INNER JOIN sgs USING (state_group) + WHERE (type = ? AND state_key = ?) + ORDER BY type, state_key, state_group DESC + """ + ) + + overall_select_clause = ( + "(" + (") UNION (".join(select_clause_list)) + ")" + ) + else: + where_clause, where_args = state_filter.make_sql_filter_clause() + + overall_select_query_args = where_args + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + overall_select_clause = f""" + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM sgs + ) {where_clause} + ORDER BY type, state_key, state_group DESC + """ + + logger.info("overall_select_clause=%s", overall_select_clause) + logger.info("overall_select_query_args=%s", overall_select_query_args) + for group in groups: args: List[Union[int, str]] = [group] - args.extend(where_args) + args.extend(overall_select_query_args) - txn.execute(sql % (where_clause,), args) + txn.execute(sql % (overall_select_clause,), args) for row in txn: typ, state_key, event_id = row key = (intern_string(typ), intern_string(state_key))