diff --git a/mautrix_telegram/db/disappearing_message.py b/mautrix_telegram/db/disappearing_message.py index 6c217173..6487bec3 100644 --- a/mautrix_telegram/db/disappearing_message.py +++ b/mautrix_telegram/db/disappearing_message.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, ClassVar import asyncpg +import time from mautrix.bridge import AbstractDisappearingMessage from mautrix.types import EventID, RoomID @@ -27,6 +28,7 @@ class DisappearingMessage(AbstractDisappearingMessage): + unqueued_ts: int | None = None db: ClassVar[Database] = fake_db async def insert(self) -> None: @@ -50,6 +52,40 @@ async def delete(self) -> None: def _from_row(cls, row: asyncpg.Record) -> DisappearingMessage: return cls(**row) + """ + Get all scheduled messages that will expire in given seconds that haven't yet been unqueued. + + This will also stamp them in the database for being unqueued so every time this method is called + there should be a unique set of events. If seconds is None then all events will be returned + regardless of being requested before. + + The first call on startup should be with None and subsequent with the previous value. + """ + @classmethod + async def unqueue_expiring(cls, seconds: int | None = None) -> list[DisappearingMessage]: + unqueued_ts = int(time.time() * 1000) + + rows = None + if seconds is None: + q = """ + SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message + WHERE expiration_ts <= $1 + """ + rows = await cls.db.fetch(q, unqueued_ts) + else: + q = """ + SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message + WHERE expiration_ts <= $1 AND (unqueued_ts IS NULL OR unqueued_ts < $2) + """ + rows = await cls.db.fetch(q, unqueued_ts + (seconds * 1000), unqueued_ts) + + msgs = [cls._from_row(r) for r in rows] + for msg in msgs: + msg.unqueued_ts = unqueued_ts + await msg.update() + + return msgs + @classmethod async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage | None: q = """ @@ -63,16 +99,10 @@ async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage | @classmethod async def get_all_scheduled(cls) -> list[DisappearingMessage]: - q = """ - SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message - WHERE expiration_ts IS NOT NULL - """ - return [cls._from_row(r) for r in await cls.db.fetch(q)] + # Stubbed because we pump with unqueue_expiring + return [] @classmethod async def get_unscheduled_for_room(cls, room_id: RoomID) -> list[DisappearingMessage]: - q = """ - SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message - WHERE room_id = $1 AND expiration_ts IS NULL - """ - return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)] + # Stubbed because we pump with unqueue_expiring + return [] diff --git a/mautrix_telegram/db/upgrade/__init__.py b/mautrix_telegram/db/upgrade/__init__.py index 34cedc6e..a8f0b9e8 100644 --- a/mautrix_telegram/db/upgrade/__init__.py +++ b/mautrix_telegram/db/upgrade/__init__.py @@ -21,4 +21,5 @@ v16_backfill_type, v17_message_find_recent, v18_puppet_contact_info_set, + v19_disappearing_message_unqueue, ) diff --git a/mautrix_telegram/db/upgrade/v00_latest_revision.py b/mautrix_telegram/db/upgrade/v00_latest_revision.py index 66c1989a..c65e78dd 100644 --- a/mautrix_telegram/db/upgrade/v00_latest_revision.py +++ b/mautrix_telegram/db/upgrade/v00_latest_revision.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from mautrix.util.async_db import Connection, Scheme -latest_version = 18 +latest_version = 19 async def create_latest_tables(conn: Connection, scheme: Scheme) -> int: @@ -92,10 +92,12 @@ async def create_latest_tables(conn: Connection, scheme: Scheme) -> int: event_id TEXT, expiration_seconds BIGINT, expiration_ts BIGINT, + unqueued_ts BIGINT, PRIMARY KEY (room_id, event_id) )""" ) + await conn.execute("CREATE INDEX disappearing_message_expiration_ts ON disappearing_message(expiration_ts)") await conn.execute( """CREATE TABLE puppet ( id BIGINT PRIMARY KEY, diff --git a/mautrix_telegram/db/upgrade/v19_disappearing_message_unqueue.py b/mautrix_telegram/db/upgrade/v19_disappearing_message_unqueue.py new file mode 100644 index 00000000..89c8780d --- /dev/null +++ b/mautrix_telegram/db/upgrade/v19_disappearing_message_unqueue.py @@ -0,0 +1,26 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.util.async_db import Connection + +from . import upgrade_table + + +@upgrade_table.register(description="Add index on disappearing_message expiration_ts, unqueued_ts column") +async def upgrade_v19(conn: Connection) -> None: + await conn.execute( + "ALTER TABLE disappearing_message ADD COLUMN unqueued_ts BIGINT" + ) + await conn.execute("CREATE INDEX disappearing_message_expiration_ts ON disappearing_message(expiration_ts)") diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 2ab062db..97a45216 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -322,6 +322,8 @@ class Portal(DBPortal, BasePortal): _msg_conv: putil.TelegramMessageConverter + _disappearing_event: asyncio.Event + def __init__( self, tgid: TelegramID, @@ -468,6 +470,42 @@ def set_dm_room_metadata(self) -> bool: or (self.encrypted and self.private_chat_portal_meta != "never") ) + @classmethod + async def _disappearing_message_loop(cls, seconds: int | None = None) -> None: + try: + seconds = None + while True: + print("fetching disappearing") + cls._disappearing_event.clear() + msgs = await cls.disappearing_msg_class.unqueue_expiring(seconds) + print(f"got {len(msgs)} rows") + for msg in msgs: + print("handling disappear thing") + portal = await cls.bridge.get_portal(msg.room_id) + if portal and portal.mxid: + background_task.create(portal._disappear_event(msg)) + else: + await msg.delete() + + try: + await asyncio.wait_for(cls._disappearing_event.wait(), 10) + except TimeoutError: + pass + + seconds = 10 + except RuntimeError: + return + + @classmethod + async def restart_scheduled_disappearing(cls) -> None: + cls._disappearing_event = asyncio.Event() + background_task.create(cls._disappearing_message_loop()) + + @classmethod + async def notify_disappearing_message_loop(cls) -> None: + print("notifying disappear loop") + cls._disappearing_event.set() + @classmethod def init_cls(cls, bridge: "TelegramBridge") -> None: BasePortal.bridge = bridge @@ -3531,7 +3569,7 @@ async def _mark_disappearing( ) await dm.insert() if expires_at: - background_task.create(self._disappear_event(dm)) + Portal.notify_disappearing_message_loop() async def _create_room_on_action( self, source: au.AbstractUser, action: TypeMessageAction diff --git a/mautrix_telegram/version.py b/mautrix_telegram/version.py index 0b226803..ccb8cd1d 100644 --- a/mautrix_telegram/version.py +++ b/mautrix_telegram/version.py @@ -1 +1,6 @@ -from .get_version import git_revision, git_tag, linkified_version, version +# Generated in setup.py + +git_tag = None +git_revision = "e3a067c2" +version = "0.11.3+dev.e3a067c2" +linkified_version = "0.11.3+dev.[e3a067c2](https://github.com/mautrix/telegram/commit/e3a067c27aa3d9dd5e82db307218cc66c8356ddd)"