This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix limit logic for EventsStream #7358

Merged 6 commits on Apr 29, 2020
Changes from 3 commits
4 changes: 3 additions & 1 deletion synapse/replication/tcp/handler.py
@@ -87,7 +87,9 @@ def __init__(self, hs):
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]

-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )

# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
28 changes: 21 additions & 7 deletions synapse/replication/tcp/streams/events.py
@@ -176,16 +176,30 @@ async def _update_function(
from_token, upper_limit, target_row_count
) # type: List[Tuple]

# again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
+        assert len(state_rows) <= target_row_count

+        # there can be more than one row per stream_id in that table, so if we hit
+        # the limit there, we'll need to truncate the results so that we have a complete
+        # set of changes for all the stream IDs we include.
        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
+            upper_limit = state_rows[-1][0] - 1

+            # search for the point to truncate the list
+            for idx in range(len(state_rows) - 1, 0, -1):
Member: I'd probably write this as "for idx, row in reversed(enumerate(state_rows)):" or something.

Member Author: You can't do reversed(enumerate(...)), because enumerate returns a generator.
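
As a minimal sketch of the problem and the usual workaround (reversed() needs a sequence, and enumerate() returns only a lazy iterator, so it has to be materialised first):

    rows = ["a", "b", "c"]

    # reversed() requires a sequence (__len__ and __getitem__) or __reversed__;
    # an enumerate object provides neither, so this raises TypeError:
    #     reversed(enumerate(rows))

    # Materialising the iterator first works, at the cost of building a list:
    for idx, row in reversed(list(enumerate(rows))):
        print(idx, row)  # 2 c / 1 b / 0 a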

Member: Well, that's tedious, isn't it. Might be worth a comment saying that we're walking backwards to find the last row of the set of rows with the second-to-last stream ID. Or something. But with better words. Mebs. 🤷‍♂️

+                if state_rows[idx - 1][0] <= upper_limit:
+                    state_rows = state_rows[:idx]
+                    break
+            else:
+                # bother. We didn't get a full set of changes for even a single
+                # stream id. let's run the query again, without a row limit, but for
+                # just one stream id.
+                upper_limit += 1
+                state_rows = await self._store.get_all_updated_current_state_deltas(
+                    from_token, upper_limit, limit=None
+                )
Member: I think we should probably move this logic into get_all_updated_current_state_deltas, as I think it's basically broken if it doesn't return all rows for a given stream ID.
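
One possible shape of that suggestion, purely as an illustrative sketch (this is not what the PR does, and _query is a hypothetical helper): the store itself would drop the trailing, possibly-incomplete stream_id group:

    async def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
        rows = await self._query(from_token, to_token, limit)  # hypothetical helper
        if limit is not None and len(rows) == limit:
            # the LIMIT may have cut the last stream_id group short, so drop
            # that group entirely
            last_stream_id = rows[-1][0]
            while rows and rows[-1][0] == last_stream_id:
                rows.pop()
        # note: if every row shared a single stream_id this returns nothing,
        # so a fallback re-query (as in the PR above) would still be needed
        return rows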


-            # FIXME: is it a given that there is only one row per stream_id in the
-            # state_deltas table (so that we can be sure that we have got all of the
-            # rows for upper_limit)?
+            limited = True

# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
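To make the new truncation logic concrete, here is a standalone sketch on toy data (illustrative only, not part of the PR):

    # toy rows of (stream_id, event_id): stream_ids may repeat, and consumers
    # must never see a partial set of rows for any stream_id
    state_rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]
    target_row_count = 4

    if len(state_rows) == target_row_count:
        # the row limit may have cut the last group short, so only keep rows
        # up to the previous stream_id
        upper_limit = state_rows[-1][0] - 1

        # walk backwards to find the last row with a stream_id <= upper_limit
        for idx in range(len(state_rows) - 1, 0, -1):
            if state_rows[idx - 1][0] <= upper_limit:
                state_rows = state_rows[:idx]
                break
        else:
            # all rows shared one stream_id: this is where the real code
            # re-queries with limit=None
            pass

    print(state_rows)  # [(1, 'a'), (2, 'b'), (2, 'c')]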
5 changes: 5 additions & 0 deletions synapse/server.pyi
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
+from synapse.events.builder import EventBuilderFactory

class HomeServer(object):
@property
@@ -121,3 +122,7 @@ class HomeServer(object):
        pass
    def get_instance_id(self) -> str:
        pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass
14 changes: 11 additions & 3 deletions synapse/storage/data_stores/main/events_worker.py
@@ -1084,15 +1084,23 @@ def get_all_new_backfill_event_rows(txn):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)

-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+    def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, limit: Optional[int]
+    ):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC LIMIT ?
+                ORDER BY stream_id ASC
"""
-            txn.execute(sql, (from_token, to_token, limit))
+            params = [from_token, to_token]
+
+            if limit is not None:
+                sql += "LIMIT ?"
+                params.append(limit)
+
+            txn.execute(sql, params)
return txn.fetchall()

return self.db.runInteraction(
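As a sanity check on the change above, a self-contained sketch of the two query shapes (build_deltas_query is an illustrative stand-in, not Synapse code):

    from typing import Any, List, Optional, Tuple

    def build_deltas_query(
        from_token: int, to_token: int, limit: Optional[int]
    ) -> Tuple[str, List[Any]]:
        sql = """
            SELECT stream_id, room_id, type, state_key, event_id
            FROM current_state_delta_stream
            WHERE ? < stream_id AND stream_id <= ?
            ORDER BY stream_id ASC
        """
        params = [from_token, to_token]  # type: List[Any]
        if limit is not None:
            # the string literal ends with a newline and spaces, so the
            # appended clause stays valid SQL
            sql += "LIMIT ?"
            params.append(limit)
        return sql, params

    print(build_deltas_query(10, 20, 100))   # "...ORDER BY stream_id ASC ... LIMIT ?", [10, 20, 100]
    print(build_deltas_query(10, 20, None))  # no LIMIT clause at all, [10, 20]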
41 changes: 24 additions & 17 deletions tests/replication/tcp/streams/_base.py
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-from typing import Optional
-
-from mock import Mock
+import logging
+from typing import Any, Dict, List, Optional, Tuple

import attr

@@ -25,6 +24,7 @@

from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest
+from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ def prepare(self, reactor, clock, hs):
# databases objects are the same.
self.worker_hs.get_datastore().db = hs.get_datastore().db

-        self.test_handler = Mock(
-            wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
-        )
+        self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler

repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ def prepare(self, reactor, clock, hs):
self._client_transport = None
self._server_transport = None

+    def _build_replication_data_handler(self):
+        return TestReplicationDataHandler(self.worker_hs.get_datastore())

def reconnect(self):
if self._client_transport:
self.client.close()
@@ -174,22 +175,28 @@ def assert_request_is_get_repl_stream_updates(
class TestReplicationDataHandler(ReplicationDataHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""

-    def __init__(self, hs):
-        super().__init__(hs)
-        self.streams = set()
-        self._received_rdata_rows = []
+    def __init__(self, store: BaseSlavedStore):
+        super().__init__(store)
+
+        # streams to subscribe to: map from stream id to position
+        self.stream_positions = {}  # type: Dict[str, int]
+
+        # list of received (stream_name, token, row) tuples
+        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]

    def get_streams_to_replicate(self):
-        positions = {s: 0 for s in self.streams}
-        for stream, token, _ in self._received_rdata_rows:
-            if stream in self.streams:
-                positions[stream] = max(token, positions.get(stream, 0))
-        return positions
+        return self.stream_positions

async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
-            self._received_rdata_rows.append((stream_name, token, r))
+            self.received_rdata_rows.append((stream_name, token, r))

+        if (
+            stream_name in self.stream_positions
+            and token > self.stream_positions[stream_name]
+        ):
+            self.stream_positions[stream_name] = token


@attr.s()
Expand Down Expand Up @@ -221,7 +228,7 @@ def __init__(self, reactor: IReactorTime):
super().__init__()
self.reactor = reactor

-        self._pull_to_push_producer = None
+        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]

def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.
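For context, a hedged sketch of how a test might drive the reworked handler (the test body and the event-writing step are hypothetical; stream_positions and received_rdata_rows are the attributes added above):

    def test_receives_events_rdata(self):
        # ask to subscribe to the events stream from position 0
        self.test_handler.stream_positions["events"] = 0
        self.reconnect()

        # ... write an event on the master, then pump replication ...

        # every RDATA row the worker saw is recorded as (stream, token, row)
        events_rows = [
            row
            for stream, token, row in self.test_handler.received_rdata_rows
            if stream == "events"
        ]
        self.assertTrue(events_rows)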