
Fix limit logic for EventsStream #7358

Merged · 6 commits · Apr 29, 2020
1 change: 1 addition & 0 deletions changelog.d/7358.bugfix
@@ -0,0 +1 @@
+Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
4 changes: 3 additions & 1 deletion synapse/replication/tcp/handler.py
@@ -87,7 +87,9 @@ def __init__(self, hs):
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
 
-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
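Why handler.py now passes a clock: when no clock is given, Linearizer appears to fall back to building its own Clock on the global reactor, which bypasses the fake reactor that the reworked tests below drive. A minimal sketch of the idea, assuming that fallback (build_linearizer is an illustrative helper, not part of the PR):

    from synapse.util import Clock
    from synapse.util.async_helpers import Linearizer

    def build_linearizer(reactor) -> Linearizer:
        # pass the clock that belongs to this reactor, rather than letting
        # Linearizer construct one from the global reactor
        return Linearizer("replication_position", clock=Clock(reactor))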
22 changes: 8 additions & 14 deletions synapse/replication/tcp/streams/events.py
@@ -170,22 +170,16 @@ async def _update_function(
         limited = False
         upper_limit = current_token
 
-        # next up is the state delta table
-
-        state_rows = await self._store.get_all_updated_current_state_deltas(
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
             from_token, upper_limit, target_row_count
-        )  # type: List[Tuple]
-
-        # again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
-        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
+        )
 
-        # FIXME: is it a given that there is only one row per stream_id in the
-        # state_deltas table (so that we can be sure that we have got all of the
-        # rows for upper_limit)?
+        limited = limited or state_rows_limited
 
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
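The deleted block above is where the bug lived. The assert requires the state-delta query to return fewer rows than target_row_count, so it raises as soon as the query fills its batch, which is exactly the case where the stream has fallen behind; and even aside from that, the `if len(state_rows) == target_row_count` branch is unreachable, since the assert has already ruled out equality. A toy reproduction (not PR code):

    limited = False
    state_rows = [(1, "a"), (2, "b"), (3, "c")]
    target_row_count = 3

    assert len(state_rows) < target_row_count  # AssertionError on a full batch
    if len(state_rows) == target_row_count:  # dead code: excluded by the assert
        limited = True

The fix delegates the limit handling to the storage layer, which reports back the token it actually reached and whether it was limited.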
5 changes: 5 additions & 0 deletions synapse/server.pyi
@@ -25,6 +25,7 @@
 import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
 import synapse.state
 import synapse.storage
+from synapse.events.builder import EventBuilderFactory
 
 class HomeServer(object):
     @property
@@ -121,3 +122,7 @@ class HomeServer(object):
         pass
     def get_instance_id(self) -> str:
         pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass
64 changes: 60 additions & 4 deletions synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from canonicaljson import json
 from constantly import NamedConstant, Names
@@ -1084,18 +1084,74 @@ def get_all_new_backfill_event_rows(txn):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)

def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
async def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream

Args:
from_token: The previous stream token. Updates from this stream id will
be excluded.

to_token: The current stream token (ie the upper limit). Updates up to this
stream id will be included (modulo the 'limit' param)

target_row_count: The number of rows to try to return. If more rows are
available, we will set 'limited' in the result. In the event of a large
batch, we may return more rows than this.
Returns:
A triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of database tuples.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""

def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
txn.execute(sql, (from_token, to_token, target_row_count))
return txn.fetchall()

return self.db.runInteraction(
def get_deltas_for_stream_id_txn(txn, stream_id):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()

# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.

rows = await self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
) # type: List[Tuple]

# if we've got fewer rows than the limit, we're good
if len(rows) < target_row_count:
return rows, to_token, False

# we hit the limit, so reduce the upper limit so that we exclude the stream id
# of the last row in the result.
assert rows[-1][0] <= to_token
to_token = rows[-1][0] - 1

# search backwards through the list for the point to truncate
for idx in range(len(rows) - 1, 0, -1):
if rows[idx - 1][0] <= to_token:
return rows[:idx], to_token, True

# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
to_token += 1
rows = await self.db.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
return rows, to_token, True
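The heart of the new method is the truncation step at the end: several rows can share a stream id, so a batch must never split a stream id in two. The same algorithm, restated as a standalone function with a worked example (illustrative code, not part of the PR):

    from typing import List, Tuple

    def truncate_to_complete_stream_ids(
        rows: List[Tuple[int, str]], to_token: int, target_row_count: int
    ) -> Tuple[List[Tuple[int, str]], int, bool]:
        # fewer rows than the limit: nothing can have been cut off
        if len(rows) < target_row_count:
            return rows, to_token, False

        # the LIMIT may have cut the last stream id short, so step the
        # token back to just before it...
        to_token = rows[-1][0] - 1

        # ...and drop every row belonging to that possibly-incomplete id
        for idx in range(len(rows) - 1, 0, -1):
            if rows[idx - 1][0] <= to_token:
                return rows[:idx], to_token, True

        # every row shares one stream id; the real code re-queries that
        # single stream id with no row limit (get_deltas_for_stream_id_txn)
        raise NotImplementedError

    # five rows hit a limit of five; stream id 3 may be incomplete, so only
    # rows up to stream id 2 are returned, with limited=True:
    rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d"), (3, "e")]
    print(truncate_to_complete_stream_ids(rows, to_token=10, target_row_count=5))
    # -> ([(1, 'a'), (2, 'b'), (2, 'c')], 2, True)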
41 changes: 24 additions & 17 deletions tests/replication/tcp/streams/_base.py
@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from typing import Optional
-
-from mock import Mock
+import logging
+from typing import Any, Dict, List, Optional, Tuple
 
 import attr

@@ -25,6 +24,7 @@

 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.site import SynapseRequest
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ def prepare(self, reactor, clock, hs):
         # databases objects are the same.
         self.worker_hs.get_datastore().db = hs.get_datastore().db
 
-        self.test_handler = Mock(
-            wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
-        )
+        self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler
 
         repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ def prepare(self, reactor, clock, hs):
         self._client_transport = None
         self._server_transport = None
 
+    def _build_replication_data_handler(self):
+        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+
     def reconnect(self):
         if self._client_transport:
             self.client.close()
@@ -174,22 +175,28 @@ def assert_request_is_get_repl_stream_updates(
 class TestReplicationDataHandler(ReplicationDataHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.streams = set()
-        self._received_rdata_rows = []
+    def __init__(self, store: BaseSlavedStore):
+        super().__init__(store)
+
+        # streams to subscribe to: map from stream id to position
+        self.stream_positions = {}  # type: Dict[str, int]
+
+        # list of received (stream_name, token, row) tuples
+        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
     def get_streams_to_replicate(self):
-        positions = {s: 0 for s in self.streams}
-        for stream, token, _ in self._received_rdata_rows:
-            if stream in self.streams:
-                positions[stream] = max(token, positions.get(stream, 0))
-        return positions
+        return self.stream_positions
 
     async def on_rdata(self, stream_name, token, rows):
         await super().on_rdata(stream_name, token, rows)
         for r in rows:
-            self._received_rdata_rows.append((stream_name, token, r))
+            self.received_rdata_rows.append((stream_name, token, r))
+
+        if (
+            stream_name in self.stream_positions
+            and token > self.stream_positions[stream_name]
+        ):
+            self.stream_positions[stream_name] = token


 @attr.s()
@@ -221,7 +228,7 @@ def __init__(self, reactor: IReactorTime):
         super().__init__()
         self.reactor = reactor
 
-        self._pull_to_push_producer = None
+        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
 
     def registerProducer(self, producer, streaming):
         # Convert pull producers to push producer.
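With the Mock wrapper gone, a stream test customises its handler by overriding _build_replication_data_handler, and subscribes by seeding stream_positions. A sketch of the intended usage (the subclass and the starting position are illustrative, not from this PR; "events" is EventsStream.NAME):

    class EventsStreamTestCase(BaseStreamTestCase):
        def _build_replication_data_handler(self):
            # subscribe to the events stream from position 0
            handler = TestReplicationDataHandler(self.worker_hs.get_datastore())
            handler.stream_positions["events"] = 0
            return handler

on_rdata then records every row in received_rdata_rows and advances stream_positions, so get_streams_to_replicate re-subscribes from the highest token seen after a reconnect().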