Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203)
Browse files Browse the repository at this point in the history

This is so that we can use it for the backfill events stream.
  • Loading branch information
erikjohnston authored Sep 1, 2020
1 parent 318245e commit bbb3c86
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog.d/8203.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `MultiWriterIDGenerator` work for streams that use negative values.
39 changes: 28 additions & 11 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""

def __init__(
Expand All @@ -196,13 +198,19 @@ def __init__(
instance_column: str,
id_column: str,
sequence_name: str,
positive: bool = True,
):
self._db = db
self._instance_name = instance_name
self._positive = positive
self._return_factor = 1 if positive else -1

# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()

# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
Expand Down Expand Up @@ -233,13 +241,16 @@ def __init__(
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
# If positive stream aggregate via MAX. For negative stream use MIN
# *and* negate the result to get a positive number.
sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
}

cur = db_conn.cursor()
Expand Down Expand Up @@ -269,15 +280,16 @@ async def get_next(self):
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id

with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id

self._unfinished_ids.add(next_id)

@contextlib.contextmanager
def manager():
try:
yield next_id
# Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)

Expand All @@ -296,15 +308,15 @@ async def get_next_mult(self, n: int):
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)

with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)

self._unfinished_ids.update(next_ids)

@contextlib.contextmanager
def manager():
try:
yield next_ids
yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
Expand All @@ -327,7 +339,7 @@ def get_next_txn(self, txn: LoggingTransaction):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)

return next_id
return self._return_factor * next_id

def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
Expand Down Expand Up @@ -359,20 +371,25 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
"""

with self._lock:
return self._current_positions.get(instance_name, 0)
return self._return_factor * self._current_positions.get(instance_name, 0)

def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""

with self._lock:
return dict(self._current_positions)
return {
name: self._return_factor * i
for name, i in self._current_positions.items()
}

def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""

new_id *= self._return_factor

with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
Expand All @@ -390,7 +407,7 @@ def get_persisted_upto_position(self) -> int:
"""

with self._lock:
return self._persisted_upto_position
return self._return_factor * self._persisted_upto_position

def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
Expand Down
105 changes: 105 additions & 0 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,108 @@ def test_get_persisted_upto_position_get_next(self):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).


class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
    """Tests a MultiWriterIdGenerator configured with ``positive=False``,
    i.e. one that hands out *negative* stream IDs.
    """

    # MultiWriterIdGenerator relies on a Postgres sequence, so these tests
    # cannot run against SQLite.
    if not USE_POSTGRES_FOR_TESTS:
        skip = "Requires Postgres"

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastore()
        self.db_pool = self.store.db_pool  # type: DatabasePool

        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

    def _setup_db(self, txn):
        # Create the sequence and backing table used by the generator under test.
        txn.execute("CREATE SEQUENCE foobar_seq")
        txn.execute(
            """
            CREATE TABLE foobar (
                stream_id BIGINT NOT NULL,
                instance_name TEXT NOT NULL,
                data TEXT
            );
            """
        )

    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
        """Build a negative-stream ID generator bound to the given instance."""

        def _make(db_conn):
            return MultiWriterIdGenerator(
                db_conn,
                self.db_pool,
                instance_name=instance_name,
                table="foobar",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="foobar_seq",
                positive=False,
            )

        return self.get_success(self.db_pool.runWithConnection(_make))

    def _insert_row(self, instance_name: str, stream_id: int):
        """Insert one row into the test table as the named writer."""

        def _do_insert(txn):
            txn.execute(
                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
            )

        self.get_success(self.db_pool.runInteraction("_insert_row", _do_insert))

    def test_single_instance(self):
        """A single writer's reads and writes produce correctly signed,
        monotonically decreasing IDs.
        """
        id_gen = self._create_id_generator()

        # First ID handed out should be -1.
        with self.get_success(id_gen.get_next()) as sid:
            self._insert_row("master", sid)

        self.assertEqual(id_gen.get_positions(), {"master": -1})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
        self.assertEqual(id_gen.get_persisted_upto_position(), -1)

        # Fetching three at once should advance us to -4.
        with self.get_success(id_gen.get_next_mult(3)) as sids:
            for sid in sids:
                self._insert_row("master", sid)

        self.assertEqual(id_gen.get_positions(), {"master": -4})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
        self.assertEqual(id_gen.get_persisted_upto_position(), -4)

        # A freshly constructed generator must recover the same state from
        # the database.
        second_id_gen = self._create_id_generator()

        self.assertEqual(second_id_gen.get_positions(), {"master": -4})
        self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
        self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)

    def test_multiple_instance(self):
        """Positions advanced over replication from another writer are
        tracked correctly.
        """
        id_gen_1 = self._create_id_generator("first")
        id_gen_2 = self._create_id_generator("second")

        # Writer "first" persists a row; "second" learns of it via advance().
        with self.get_success(id_gen_1.get_next()) as sid:
            self._insert_row("first", sid)
            id_gen_2.advance("first", sid)

        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
        self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
        self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)

        # Now the other way around.
        with self.get_success(id_gen_2.get_next()) as sid:
            self._insert_row("second", sid)
            id_gen_1.advance("second", sid)

        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
        self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
        self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)

0 comments on commit bbb3c86

Please sign in to comment.