diff --git a/.circleci/config.yml b/.circleci/config.yml index 98c217dd1d93..5bd2ab2b7656 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,18 +4,16 @@ jobs: machine: true steps: - checkout - - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} -t matrixdotorg/synapse:${CIRCLE_TAG}-py3 . + - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} . - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - run: docker push matrixdotorg/synapse:${CIRCLE_TAG} - - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}-py3 dockerhubuploadlatest: machine: true steps: - checkout - - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest -t matrixdotorg/synapse:latest-py3 . + - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest . - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - run: docker push matrixdotorg/synapse:latest - - run: docker push matrixdotorg/synapse:latest-py3 workflows: version: 2 diff --git a/CHANGES.md b/CHANGES.md index 6c986808eb3e..a2c8232be1d7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,72 @@ +Synapse 1.19.0rc1 (2020-08-13) +============================== + +Removal warning +--------------- + +As outlined in the [previous release](https://github.com/matrix-org/synapse/releases/tag/v1.18.0), we are no longer publishing Docker images with the `-py3` tag suffix. On top of that, we have also removed the `latest-py3` tag. Please see [the announcement in the upgrade notes for 1.18.0](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180). + + +Features +-------- + +- Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. ([\#7902](https://github.com/matrix-org/synapse/issues/7902)) +- Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. ([\#7964](https://github.com/matrix-org/synapse/issues/7964)) +- Add rate limiting to users joining rooms. ([\#8008](https://github.com/matrix-org/synapse/issues/8008)) +- Add a `/health` endpoint to every configured HTTP listener that can be used as a health check endpoint by load balancers. ([\#8048](https://github.com/matrix-org/synapse/issues/8048)) +- Allow login to be blocked based on the values of SAML attributes. ([\#8052](https://github.com/matrix-org/synapse/issues/8052)) +- Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7314](https://github.com/matrix-org/synapse/issues/7314)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory. ([\#7977](https://github.com/matrix-org/synapse/issues/7977)) +- Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured. ([\#7978](https://github.com/matrix-org/synapse/issues/7978)) +- Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. ([\#7980](https://github.com/matrix-org/synapse/issues/7980)) +- Fix various comments and minor discrepencies in server notices code. ([\#7996](https://github.com/matrix-org/synapse/issues/7996)) +- Fix a long standing bug where HTTP HEAD requests resulted in a 400 error. ([\#7999](https://github.com/matrix-org/synapse/issues/7999)) +- Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger. ([\#8011](https://github.com/matrix-org/synapse/issues/8011), [\#8012](https://github.com/matrix-org/synapse/issues/8012)) + + +Updates to the Docker image +--------------------------- + +- We no longer publish Docker images with the `-py3` tag suffix, as [announced in the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180). ([\#8056](https://github.com/matrix-org/synapse/issues/8056)) + + +Improved Documentation +---------------------- + +- Document how to set up a client .well-known file and fix several pieces of outdated documentation. ([\#7899](https://github.com/matrix-org/synapse/issues/7899)) +- Improve workers docs. ([\#7990](https://github.com/matrix-org/synapse/issues/7990), [\#8000](https://github.com/matrix-org/synapse/issues/8000)) +- Fix typo in `docs/workers.md`. ([\#7992](https://github.com/matrix-org/synapse/issues/7992)) +- Add documentation for how to undo a room shutdown. ([\#7998](https://github.com/matrix-org/synapse/issues/7998), [\#8010](https://github.com/matrix-org/synapse/issues/8010)) + + +Internal Changes +---------------- + +- Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo. ([\#7372](https://github.com/matrix-org/synapse/issues/7372)) +- Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. ([\#7936](https://github.com/matrix-org/synapse/issues/7936), [\#7979](https://github.com/matrix-org/synapse/issues/7979)) +- Convert various parts of the codebase to async/await. ([\#7947](https://github.com/matrix-org/synapse/issues/7947), [\#7948](https://github.com/matrix-org/synapse/issues/7948), [\#7949](https://github.com/matrix-org/synapse/issues/7949), [\#7951](https://github.com/matrix-org/synapse/issues/7951), [\#7963](https://github.com/matrix-org/synapse/issues/7963), [\#7973](https://github.com/matrix-org/synapse/issues/7973), [\#7975](https://github.com/matrix-org/synapse/issues/7975), [\#7976](https://github.com/matrix-org/synapse/issues/7976), [\#7981](https://github.com/matrix-org/synapse/issues/7981), [\#7987](https://github.com/matrix-org/synapse/issues/7987), [\#7989](https://github.com/matrix-org/synapse/issues/7989), [\#8003](https://github.com/matrix-org/synapse/issues/8003), [\#8014](https://github.com/matrix-org/synapse/issues/8014), [\#8016](https://github.com/matrix-org/synapse/issues/8016), [\#8027](https://github.com/matrix-org/synapse/issues/8027), [\#8031](https://github.com/matrix-org/synapse/issues/8031), [\#8032](https://github.com/matrix-org/synapse/issues/8032), [\#8035](https://github.com/matrix-org/synapse/issues/8035), [\#8042](https://github.com/matrix-org/synapse/issues/8042), [\#8044](https://github.com/matrix-org/synapse/issues/8044), [\#8045](https://github.com/matrix-org/synapse/issues/8045), [\#8061](https://github.com/matrix-org/synapse/issues/8061), [\#8062](https://github.com/matrix-org/synapse/issues/8062), [\#8063](https://github.com/matrix-org/synapse/issues/8063), [\#8066](https://github.com/matrix-org/synapse/issues/8066), [\#8069](https://github.com/matrix-org/synapse/issues/8069), [\#8070](https://github.com/matrix-org/synapse/issues/8070)) +- Move some database-related log lines from the default logger to the database/transaction loggers. ([\#7952](https://github.com/matrix-org/synapse/issues/7952)) +- Add a script to detect source code files using non-unix line terminators. ([\#7965](https://github.com/matrix-org/synapse/issues/7965), [\#7970](https://github.com/matrix-org/synapse/issues/7970)) +- Log the SAML session ID during creation. ([\#7971](https://github.com/matrix-org/synapse/issues/7971)) +- Implement new experimental push rules for some users. ([\#7997](https://github.com/matrix-org/synapse/issues/7997)) +- Remove redundant and unreliable signature check for v1 Identity Service lookup responses. ([\#8001](https://github.com/matrix-org/synapse/issues/8001)) +- Improve the performance of the register endpoint. ([\#8009](https://github.com/matrix-org/synapse/issues/8009)) +- Reduce less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error. ([\#8024](https://github.com/matrix-org/synapse/issues/8024)) +- Rename storage layer objects to be more sensible. ([\#8033](https://github.com/matrix-org/synapse/issues/8033)) +- Change the default log config to reduce disk I/O and storage for new servers. ([\#8040](https://github.com/matrix-org/synapse/issues/8040)) +- Add an assertion on `prev_events` in `create_new_client_event`. ([\#8041](https://github.com/matrix-org/synapse/issues/8041)) +- Add a comment to `ServerContextFactory` about the use of `SSLv23_METHOD`. ([\#8043](https://github.com/matrix-org/synapse/issues/8043)) +- Log `OPTIONS` requests at `DEBUG` rather than `INFO` level to reduce amount logged at `INFO`. ([\#8049](https://github.com/matrix-org/synapse/issues/8049)) +- Reduce amount of outbound request logging at `INFO` level. ([\#8050](https://github.com/matrix-org/synapse/issues/8050)) +- It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.) ([\#8051](https://github.com/matrix-org/synapse/issues/8051)) +- Add and improve type hints. ([\#8058](https://github.com/matrix-org/synapse/issues/8058), [\#8064](https://github.com/matrix-org/synapse/issues/8064), [\#8060](https://github.com/matrix-org/synapse/issues/8060), [\#8067](https://github.com/matrix-org/synapse/issues/8067)) + + Synapse 1.18.0 (2020-07-30) =========================== diff --git a/changelog.d/7314.misc b/changelog.d/7314.misc deleted file mode 100644 index 30720100c298..000000000000 --- a/changelog.d/7314.misc +++ /dev/null @@ -1 +0,0 @@ -Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature deleted file mode 100644 index feb02be234bf..000000000000 --- a/changelog.d/7736.feature +++ /dev/null @@ -1 +0,0 @@ -Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/changelog.d/7864.bugfix b/changelog.d/7864.bugfix new file mode 100644 index 000000000000..8623355fe921 --- /dev/null +++ b/changelog.d/7864.bugfix @@ -0,0 +1 @@ +Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable. diff --git a/changelog.d/7899.doc b/changelog.d/7899.doc deleted file mode 100644 index 847c2cb62c4f..000000000000 --- a/changelog.d/7899.doc +++ /dev/null @@ -1 +0,0 @@ -Document how to set up a Client Well-Known file and fix several pieces of outdated documentation. diff --git a/changelog.d/7902.feature b/changelog.d/7902.feature deleted file mode 100644 index 4feae8cc2955..000000000000 --- a/changelog.d/7902.feature +++ /dev/null @@ -1 +0,0 @@ -Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. diff --git a/changelog.d/7936.misc b/changelog.d/7936.misc deleted file mode 100644 index 4304bbdd2597..000000000000 --- a/changelog.d/7936.misc +++ /dev/null @@ -1 +0,0 @@ -Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7948.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7949.misc b/changelog.d/7949.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7949.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7951.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7952.misc b/changelog.d/7952.misc deleted file mode 100644 index 93c25cb386bd..000000000000 --- a/changelog.d/7952.misc +++ /dev/null @@ -1 +0,0 @@ -Move some database-related log lines from the default logger to the database/transaction loggers. \ No newline at end of file diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7963.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7964.feature b/changelog.d/7964.feature deleted file mode 100644 index ffe861650ce2..000000000000 --- a/changelog.d/7964.feature +++ /dev/null @@ -1 +0,0 @@ -Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/7965.misc b/changelog.d/7965.misc deleted file mode 100644 index ee9f1a7114a8..000000000000 --- a/changelog.d/7965.misc +++ /dev/null @@ -1 +0,0 @@ -Add a script to detect source code files using non-unix line terminators. \ No newline at end of file diff --git a/changelog.d/7970.misc b/changelog.d/7970.misc deleted file mode 100644 index ee9f1a7114a8..000000000000 --- a/changelog.d/7970.misc +++ /dev/null @@ -1 +0,0 @@ -Add a script to detect source code files using non-unix line terminators. \ No newline at end of file diff --git a/changelog.d/7971.misc b/changelog.d/7971.misc deleted file mode 100644 index 87a4eb1f4d66..000000000000 --- a/changelog.d/7971.misc +++ /dev/null @@ -1 +0,0 @@ -Log the SAML session ID during creation. diff --git a/changelog.d/7973.misc b/changelog.d/7973.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7973.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7975.misc b/changelog.d/7975.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7975.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7976.misc b/changelog.d/7976.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7976.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7977.bugfix b/changelog.d/7977.bugfix deleted file mode 100644 index c587f1305567..000000000000 --- a/changelog.d/7977.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory. diff --git a/changelog.d/7978.bugfix b/changelog.d/7978.bugfix deleted file mode 100644 index 247b18db20ba..000000000000 --- a/changelog.d/7978.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured. diff --git a/changelog.d/7979.misc b/changelog.d/7979.misc deleted file mode 100644 index 4304bbdd2597..000000000000 --- a/changelog.d/7979.misc +++ /dev/null @@ -1 +0,0 @@ -Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. diff --git a/changelog.d/7980.bugfix b/changelog.d/7980.bugfix deleted file mode 100644 index fa351b4b7744..000000000000 --- a/changelog.d/7980.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. diff --git a/changelog.d/7981.misc b/changelog.d/7981.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7981.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7987.misc b/changelog.d/7987.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7987.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7989.misc b/changelog.d/7989.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/7989.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7990.doc b/changelog.d/7990.doc deleted file mode 100644 index 8d8fd926e93c..000000000000 --- a/changelog.d/7990.doc +++ /dev/null @@ -1 +0,0 @@ -Improve workers docs. diff --git a/changelog.d/7992.doc b/changelog.d/7992.doc deleted file mode 100644 index 3368fb59126a..000000000000 --- a/changelog.d/7992.doc +++ /dev/null @@ -1 +0,0 @@ -Fix typo in `docs/workers.md`. diff --git a/changelog.d/7996.bugfix b/changelog.d/7996.bugfix deleted file mode 100644 index 1e51f2055829..000000000000 --- a/changelog.d/7996.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix various comments and minor discrepencies in server notices code. diff --git a/changelog.d/7998.doc b/changelog.d/7998.doc deleted file mode 100644 index fc8b3f0c3df5..000000000000 --- a/changelog.d/7998.doc +++ /dev/null @@ -1 +0,0 @@ -Add documentation for how to undo a room shutdown. diff --git a/changelog.d/7999.bugfix b/changelog.d/7999.bugfix deleted file mode 100644 index e0b8c4922f86..000000000000 --- a/changelog.d/7999.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long standing bug where HTTP HEAD requests resulted in a 400 error. diff --git a/changelog.d/8001.misc b/changelog.d/8001.misc deleted file mode 100644 index 0be4b37d22c5..000000000000 --- a/changelog.d/8001.misc +++ /dev/null @@ -1 +0,0 @@ -Remove redundant and unreliable signature check for v1 Identity Service lookup responses. diff --git a/changelog.d/8003.misc b/changelog.d/8003.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/8003.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/8008.feature b/changelog.d/8008.feature deleted file mode 100644 index c6d381809aaf..000000000000 --- a/changelog.d/8008.feature +++ /dev/null @@ -1 +0,0 @@ -Add rate limiting to users joining rooms. diff --git a/changelog.d/8011.bugfix b/changelog.d/8011.bugfix deleted file mode 100644 index c673040de938..000000000000 --- a/changelog.d/8011.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger. diff --git a/changelog.d/8012.bugfix b/changelog.d/8012.bugfix deleted file mode 100644 index c673040de938..000000000000 --- a/changelog.d/8012.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger. diff --git a/changelog.d/8014.misc b/changelog.d/8014.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/8014.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/8016.misc b/changelog.d/8016.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/8016.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/8024.misc b/changelog.d/8024.misc deleted file mode 100644 index 4bc739502bf6..000000000000 --- a/changelog.d/8024.misc +++ /dev/null @@ -1 +0,0 @@ -Reduce less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error. \ No newline at end of file diff --git a/changelog.d/8027.misc b/changelog.d/8027.misc deleted file mode 100644 index dfe4c03171d6..000000000000 --- a/changelog.d/8027.misc +++ /dev/null @@ -1 +0,0 @@ -Convert various parts of the codebase to async/await. diff --git a/changelog.d/7947.misc b/changelog.d/8072.misc similarity index 100% rename from changelog.d/7947.misc rename to changelog.d/8072.misc diff --git a/changelog.d/8081.bugfix b/changelog.d/8081.bugfix new file mode 100644 index 000000000000..9ebcbf5b8448 --- /dev/null +++ b/changelog.d/8081.bugfix @@ -0,0 +1 @@ +Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed. diff --git a/docker/conf/log.config b/docker/conf/log.config index ed418a57cd9b..491bbcc87ad7 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -4,16 +4,10 @@ formatters: precise: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md index 2ff552bcb34f..9b1cb1c184b1 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md @@ -79,13 +79,20 @@ Response: the structure can and does change without notice. First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it -never happened - work has to be done to move forward instead of resetting the past. +never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible +to recover at all: -1. For safety reasons, it is recommended to shut down Synapse prior to continuing. +* If the room was invite-only, your users will need to be re-invited. +* If the room no longer has any members at all, it'll be impossible to rejoin. +* The first user to rejoin will have to do so via an alias on a different server. + +With all that being said, if you still want to try and recover the room: + +1. For safety reasons, shut down Synapse. 2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. -3. Restart Synapse (required). +3. Restart Synapse. You will have to manually handle, if you so choose, the following: diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 7bfb96eff623..fd48ba0874c2 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -139,3 +139,10 @@ client IP addresses are recorded correctly. Having done so, you can then use `https://matrix.example.com` (instead of `https://matrix.example.com:8448`) as the "Custom server" when connecting to Synapse from a client. + + +## Health check endpoint + +Synapse exposes a health check endpoint for use by reverse proxies. +Each configured HTTP listener has a `/health` endpoint which always returns +200 OK (and doesn't get logged). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fe85978a1fb1..9235b89fb1c0 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1577,6 +1577,17 @@ saml2_config: # #grandfathered_mxid_source_attribute: upn + # It is possible to configure Synapse to only allow logins if SAML attributes + # match particular values. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. + # + #attribute_requirements: + # - attribute: userGroup + # value: "staff" + # - attribute: department + # value: "sales" + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 1a2739455ef2..55a48a9ed622 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -11,24 +11,33 @@ formatters: precise: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: file: - class: logging.handlers.RotatingFileHandler + class: logging.handlers.TimedRotatingFileHandler formatter: precise filename: /var/log/matrix-synapse/homeserver.log - maxBytes: 104857600 - backupCount: 10 - filters: [context] + when: midnight + backupCount: 3 # Does not include the current log file. encoding: utf8 + + # Default to buffering writes to log file for efficiency. This means that + # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR + # logs will still be flushed immediately. + buffer: + class: logging.handlers.MemoryHandler + target: file + # The capacity is the number of log lines that are buffered before + # being written to disk. Increasing this will lead to better + # performance, at the expensive of it taking longer for log lines to + # be written to disk. + capacity: 10 + flushLevel: 30 # Flush for WARNING logs as well + + # A handler that writes logs to stderr. Unused by default, but can be used + # instead of "buffer" and "file" in the logger handlers. console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: @@ -36,8 +45,23 @@ loggers: # information such as access tokens. level: INFO + twisted: + # We send the twisted logging directly to the file handler, + # to work around https://github.com/matrix-org/synapse/issues/3471 + # when using "buffer" logger. Use "console" to log to stderr instead. + handlers: [file] + propagate: false + root: level: INFO - handlers: [file, console] + + # Write logs to the `buffer` handler, which will buffer them together in memory, + # then write them to a file. + # + # Replace "buffer" with "console" to log to stderr instead. (Note that you'll + # also need to update the configuation for the `twisted` logger above, in + # this case.) + # + handlers: [buffer] disable_existing_loggers: false diff --git a/docs/systemd-with-workers/workers/federation_reader.yaml b/docs/systemd-with-workers/workers/federation_reader.yaml index 5b65c7040d54..13e69e62c9db 100644 --- a/docs/systemd-with-workers/workers/federation_reader.yaml +++ b/docs/systemd-with-workers/workers/federation_reader.yaml @@ -1,7 +1,7 @@ worker_app: synapse.app.federation_reader +worker_name: federation_reader1 worker_replication_host: 127.0.0.1 -worker_replication_port: 9092 worker_replication_http_port: 9093 worker_listeners: diff --git a/docs/user_directory.md b/docs/user_directory.md index 37dc71e751cf..872fc2197968 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) +solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/docs/workers.md b/docs/workers.md index 80b65a0cec2a..bfec745897c2 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -23,7 +23,7 @@ The processes communicate with each other via a Synapse-specific protocol called feeds streams of newly written data between processes so they can be kept in sync with the database state. -When configured to do so, Synapse uses a +When configured to do so, Synapse uses a [Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication stream between all configured Synapse processes. Additionally, processes may make HTTP requests to each other, primarily for operations which need to wait @@ -66,23 +66,31 @@ https://hub.docker.com/r/matrixdotorg/synapse/. To make effective use of the workers, you will need to configure an HTTP reverse-proxy such as nginx or haproxy, which will direct incoming requests to -the correct worker, or to the main synapse instance. See +the correct worker, or to the main synapse instance. See [reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse proxy. -To enable workers you should create a configuration file for each worker -process. Each worker configuration file inherits the configuration of the shared -homeserver configuration file. You can then override configuration specific to -that worker, e.g. the HTTP listener that it provides (if any); logging -configuration; etc. You should minimise the number of overrides though to -maintain a usable config. +When using workers, each worker process has its own configuration file which +contains settings specific to that worker, such as the HTTP listener that it +provides (if any), logging configuration, etc. +Normally, the worker processes are configured to read from a shared +configuration file as well as the worker-specific configuration files. This +makes it easier to keep common configuration settings synchronised across all +the processes. -### Shared Configuration +The main process is somewhat special in this respect: it does not normally +need its own configuration file and can take all of its configuration from the +shared configuration file. + + +### Shared configuration + +Normally, only a couple of changes are needed to make an existing configuration +file suitable for use with workers. First, you need to enable an "HTTP replication +listener" for the main process; and secondly, you need to enable redis-based +replication. For example: -Next you need to add both a HTTP replication listener, used for HTTP requests -between processes, and redis config to the shared Synapse configuration file -(`homeserver.yaml`). For example: ```yaml # extend the existing `listeners` section. This defines the ports that the @@ -105,7 +113,7 @@ Under **no circumstances** should the replication listener be exposed to the public internet; it has no authentication and is unencrypted. -### Worker Configuration +### Worker configuration In the config file for each worker, you must specify the type of worker application (`worker_app`), and you should specify a unqiue name for the worker @@ -145,6 +153,9 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g. Obviously you should configure your reverse-proxy to route the relevant endpoints to the worker (`localhost:8083` in the above example). + +### Running Synapse with workers + Finally, you need to start your worker processes. This can be done with either `synctl` or your distribution's preferred service manager such as `systemd`. We recommend the use of `systemd` where available: for information on setting up @@ -407,6 +418,23 @@ all these to be folded into the `generic_worker` app and to use config to define which processes handle the various proccessing such as push notifications. +## Migration from old config + +There are two main independent changes that have been made: introducing Redis +support and merging apps into `synapse.app.generic_worker`. Both these changes +are backwards compatible and so no changes to the config are required, however +server admins are encouraged to plan to migrate to Redis as the old style direct +TCP replication config is deprecated. + +To migrate to Redis add the `redis` config as above, and optionally remove the +TCP `replication` listener from master and `worker_replication_port` from worker +config. + +To migrate apps to use `synapse.app.generic_worker` simply update the +`worker_app` option in the worker configs, and where worker are started (e.g. +in systemd service files, but not required for synctl). + + ## Architectural diagram The following shows an example setup using Redis and a reverse proxy: diff --git a/mypy.ini b/mypy.ini index a61009b1971f..c69cb5dc4064 100644 --- a/mypy.ini +++ b/mypy.ini @@ -81,3 +81,6 @@ ignore_missing_imports = True [mypy-rust_python_jaeger_reporter.*] ignore_missing_imports = True + +[mypy-nacl.*] +ignore_missing_imports = True diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 94aa8758b48f..56365e2b58bf 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -40,7 +40,7 @@ class MockHomeserver(HomeServer): config.server_name, reactor=reactor, config=config, **kwargs ) - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) if __name__ == "__main__": @@ -86,7 +86,7 @@ if __name__ == "__main__": store = hs.get_datastore() async def run_background_updates(): - await store.db.updates.run_background_updates(sleep=False) + await store.db_pool.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index bee525197fb2..a34bdf18302c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -35,31 +35,29 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore -from synapse.storage.data_stores.main.deviceinbox import ( - DeviceInboxBackgroundUpdateStore, -) -from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore -from synapse.storage.data_stores.main.events_bg_updates import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore +from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) -from synapse.storage.data_stores.main.media_repository import ( +from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.main.registration import ( +from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) -from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore -from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore -from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore -from synapse.storage.data_stores.main.stats import StatsStore -from synapse.storage.data_stores.main.user_directory import ( +from synapse.storage.databases.main.room import RoomBackgroundUpdateStore +from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore +from synapse.storage.databases.main.search import SearchBackgroundUpdateStore +from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database, make_conn +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -69,7 +67,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url", "count_as_unread"], + "events": ["processed", "outlier", "contains_url"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], @@ -175,14 +173,14 @@ class Store( StatsStore, ): def execute(self, f, *args, **kwargs): - return self.db.runInteraction(f.__name__, f, *args, **kwargs) + return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.db.runInteraction("execute_sql", r) + return self.db_pool.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -227,7 +225,7 @@ class Porter(object): async def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = await self.postgres_store.db.simple_select_one( + row = await self.postgres_store.db_pool.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -244,7 +242,7 @@ class Porter(object): ) = await self._setup_sent_transactions() backward_chunk = 0 else: - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -274,7 +272,7 @@ class Porter(object): await self.postgres_store.execute(delete_all) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -318,7 +316,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -359,7 +357,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = await self.sqlite_store.db.runInteraction( + headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( "select", r ) @@ -375,7 +373,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -413,7 +411,7 @@ class Porter(object): return headers, rows - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -451,7 +449,7 @@ class Porter(object): ], ) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -494,7 +492,7 @@ class Porter(object): db_conn, allow_outdated_version=allow_outdated_version ) prepare_database(db_conn, engine, config=self.hs_config) - store = Store(Database(hs, db_config, engine), db_conn, hs) + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) db_conn.commit() return store @@ -502,7 +500,7 @@ class Porter(object): async def run_background_updates_on_postgres(self): # Manually apply all background updates on the PostgreSQL database. postgres_ready = ( - await self.postgres_store.db.updates.has_completed_background_updates() + await self.postgres_store.db_pool.updates.has_completed_background_updates() ) if not postgres_ready: @@ -511,9 +509,9 @@ class Porter(object): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - await self.postgres_store.db.updates.do_next_background_update(100) + await self.postgres_store.db_pool.updates.do_next_background_update(100) postgres_ready = await ( - self.postgres_store.db.updates.has_completed_background_updates() + self.postgres_store.db_pool.updates.has_completed_background_updates() ) async def run(self): @@ -534,7 +532,7 @@ class Porter(object): # Check if all background updates are done, abort if not. updates_complete = ( - await self.sqlite_store.db.updates.has_completed_background_updates() + await self.sqlite_store.db_pool.updates.has_completed_background_updates() ) if not updates_complete: end_error = ( @@ -576,22 +574,24 @@ class Porter(object): ) try: - await self.postgres_store.db.runInteraction("alter_table", alter_table) + await self.postgres_store.db_pool.runInteraction( + "alter_table", alter_table + ) except Exception: # On Error Resume Next pass - await self.postgres_store.db.runInteraction( + await self.postgres_store.db_pool.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = await self.sqlite_store.db.simple_select_onecol( + sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = await self.postgres_store.db.simple_select_onecol( + postgres_tables = await self.postgres_store.db_pool.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -692,7 +692,7 @@ class Porter(object): return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -725,7 +725,7 @@ class Porter(object): next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -794,14 +794,14 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) def _setup_user_id_seq(self): def r(txn): next_id = find_max_generated_user_id_localpart(txn) + 1 txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db.runInteraction("setup_user_id_seq", r) + return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) ############################################## diff --git a/synapse/__init__.py b/synapse/__init__.py index f70381bc71cd..832a8e2014eb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ except ImportError: pass -__version__ = "1.18.0" +__version__ = "1.19.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2178e623da8e..d8190f92ab30 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import List, Optional, Tuple import pymacaroons from netaddr import IPAddress -from twisted.internet import defer from twisted.web.server import Request import synapse.types @@ -80,13 +79,14 @@ def __init__(self, hs): self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key - @defer.inlineCallbacks - def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) - auth_events_ids = yield self.compute_auth_events( + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ): + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -94,14 +94,13 @@ def check_from_context(self, room_version: str, event, context, do_sig_check=Tru room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) - @defer.inlineCallbacks - def check_user_in_room( + async def check_user_in_room( self, room_id: str, user_id: str, current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ): + ) -> EventBase: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -119,37 +118,35 @@ def check_user_in_room( Raises: AuthError if the user is/was not in the room. Returns: - Deferred[Optional[EventBase]]: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield defer.ensureDeferred( - self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) + member = await self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - membership = member.membership if member else None - if membership == Membership.JOIN: - return member + if member: + membership = member.membership - # XXX this looks totally bogus. Why do we not allow users who have been banned, - # or those who were members previously and have been re-invited? - if allow_departed_users and membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if not forgot: + if membership == Membership.JOIN: return member + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: + forgot = await self.store.did_forget(user_id, room_id) + if not forgot: + return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.is_host_joined(room_id, host) + latest_event_ids = await self.store.is_host_joined(room_id, host) return latest_event_ids def can_federate(self, event, auth_events): @@ -160,14 +157,13 @@ def can_federate(self, event, auth_events): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @defer.inlineCallbacks - def get_user_by_req( + async def get_user_by_req( self, request: Request, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ): + ) -> synapse.types.Requester: """ Get a registered user's ID. Args: @@ -180,7 +176,7 @@ def get_user_by_req( /login will deliver access tokens regardless of expiration. Returns: - defer.Deferred: resolves to a `synapse.types.Requester` object + Resolves to the requester Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -194,14 +190,14 @@ def get_user_by_req( access_token = self.get_access_token_from_request(request) - user_id, app_service = yield self._get_appservice_user_id(request) + user_id, app_service = await self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) if ip_addr and self._track_appservice_user_ips: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user_id, access_token=access_token, ip=ip_addr, @@ -211,7 +207,7 @@ def get_user_by_req( return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token( + user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) user = user_info["user"] @@ -221,7 +217,7 @@ def get_user_by_req( # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if ( expiration_ts is not None and self.clock.time_msec() >= expiration_ts @@ -235,7 +231,7 @@ def get_user_by_req( device_id = user_info.get("device_id") if user and access_token and ip_addr: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -261,8 +257,7 @@ def get_user_by_req( except KeyError: raise MissingClientTokenError() - @defer.inlineCallbacks - def _get_appservice_user_id(self, request): + async def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -283,14 +278,13 @@ def _get_appservice_user_id(self, request): if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (yield self.store.get_user_by_id(user_id)): + if not (await self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") return user_id, app_service - @defer.inlineCallbacks - def get_user_by_access_token( + async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ): + ) -> dict: """ Validate access token and get user_id from it Args: @@ -300,7 +294,7 @@ def get_user_by_access_token( allow_expired: If False, raises an InvalidClientTokenError if the token is expired Returns: - Deferred[dict]: dict that includes: + dict that includes: `user` (UserID) `is_guest` (bool) `token_id` (int|None): access token id. May be None if guest @@ -314,7 +308,7 @@ def get_user_by_access_token( if rights == "access": # first look in the database - r = yield self._look_up_user_by_access_token(token) + r = await self._look_up_user_by_access_token(token) if r: valid_until_ms = r["valid_until_ms"] if ( @@ -352,7 +346,7 @@ def get_user_by_access_token( # It would of course be much easier to store guest access # tokens in the database as well, but that would break existing # guest tokens. - stored_user = yield self.store.get_user_by_id(user_id) + stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: @@ -482,9 +476,8 @@ def _verify_expiry(self, caveat): now = self.hs.get_clock().time_msec() return now < expiry - @defer.inlineCallbacks - def _look_up_user_by_access_token(self, token): - ret = yield self.store.get_user_by_access_token(token) + async def _look_up_user_by_access_token(self, token): + ret = await self.store.get_user_by_access_token(token) if not ret: return None @@ -507,7 +500,7 @@ def get_appservice_by_req(self, request): logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender - return defer.succeed(service) + return service async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. @@ -522,7 +515,7 @@ async def is_server_admin(self, user: UserID) -> bool: def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, - ): + ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. @@ -530,11 +523,11 @@ def compute_auth_events( should be added to the event's `auth_events`. Returns: - defer.Deferred(list[str]): List of event IDs. + List of event IDs. """ if event.type == EventTypes.Create: - return defer.succeed([]) + return [] # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding @@ -553,7 +546,7 @@ def compute_auth_events( if auth_ev_id: auth_ids.append(auth_ev_id) - return defer.succeed(auth_ids) + return auth_ids async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the @@ -636,10 +629,9 @@ def get_access_token_from_request(request: Request): return query_params[0].decode("ascii") - @defer.inlineCallbacks - def check_user_in_room_or_world_readable( + async def check_user_in_room_or_world_readable( self, room_id: str, user_id: str, allow_departed_users: bool = False - ): + ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -650,10 +642,9 @@ def check_user_in_room_or_world_readable( members but have now departed Returns: - Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ try: @@ -662,15 +653,13 @@ def check_user_in_room_or_world_readable( # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_in_room( + member_event = await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield defer.ensureDeferred( - self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) + visibility = await self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" ) if ( visibility diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 5c499b6b4e66..49093bf18169 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved @@ -36,8 +34,7 @@ def __init__(self, hs): self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -60,7 +57,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): if user_id is not None: if user_id == self._server_notices_mxid: return - if (yield self.store.is_support_user(user_id)): + if await self.store.is_support_user(user_id): return if self._hs_disabled: @@ -76,11 +73,11 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): # If the user is already part of the MAU cohort or a trial user if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) + timestamp = await self.store.user_last_seen_monthly_active(user_id) if timestamp: return - is_trial = yield self.store.is_trial_user(user_id) + is_trial = await self.store.is_trial_user(user_id) if is_trial: return elif threepid: @@ -93,7 +90,7 @@ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): # allow registration. Support users are excluded from MAU checks. return # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() + current_mau = await self.store.get_monthly_active_count() if current_mau >= self._max_mau_value: raise ResourceLimitError( 403, diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3bab1aa526c..6e40630ab6d4 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception): (This indicates we should return a 401 with 'result' as the body) Attributes: + session_id: The ID of the ongoing interactive auth session. result: the server response to the request, which should be passed back to the client """ - def __init__(self, result: "JsonDict"): + def __init__(self, session_id: str, result: "JsonDict"): super(InteractiveAuthIncompleteError, self).__init__( "Interactive auth not yet complete" ) + self.session_id = session_id self.result = result diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f988f62a1e5f..7393d6cb741b 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -21,8 +21,6 @@ from canonicaljson import json from jsonschema import FormatChecker -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState @@ -137,9 +135,8 @@ def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_user_filter(self, user_localpart, filter_id): - result = yield self.store.get_user_filter(user_localpart, filter_id) + async def get_user_filter(self, user_localpart, filter_id): + result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index fa40c68f535f..2b2cd795e072 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -268,7 +268,7 @@ def handle_sighup(*args, **kwargs): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().db.start_profiling() + hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() setup_sentry(hs) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c478df53be85..739b013d4c3a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -123,17 +123,18 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.server import HomeServer -from synapse.storage.data_stores.main.censor_events import CensorEventsStore -from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.data_stores.main.monthly_active_users import ( +from synapse.server import HomeServer, cache_in_self +from synapse.storage.databases.main.censor_events import CensorEventsStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) -from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.data_stores.main.search import SearchWorkerStore -from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.databases.main.presence import UserPresenceState +from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore +from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -493,7 +494,10 @@ def _listen_http(self, listener_config: ListenerConfig): site_tag = listener_config.http_options.tag if site_tag is None: site_tag = port - resources = {} + + # We always include a health resource. + resources = {"/health": HealthResource()} + for res in listener_config.http_options.resources: for name in res.names: if name == "metrics": @@ -631,10 +635,12 @@ def start_listening(self, listeners: Iterable[ListenerConfig]): async def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_replication_data_handler(self): + @cache_in_self + def get_replication_data_handler(self): return GenericWorkerReplicationHandler(self) - def build_presence_handler(self): + @cache_in_self + def get_presence_handler(self): return GenericWorkerPresence(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b011e00b4b56..98d0d14a124b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -68,6 +68,7 @@ from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer @@ -98,7 +99,9 @@ def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConf if site_tag is None: site_tag = port - resources = {} + # We always include a health resource. + resources = {"/health": HealthResource()} + for res in listener_config.http_options.resources: for name in res.names: if name == "openid" and "federation" in res.names: @@ -441,7 +444,7 @@ def start(): _base.start(hs, config.listeners) - hs.get_datastore().db.updates.start_doing_background_updates() + hs.get_datastore().db_pool.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) @@ -551,8 +554,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): # # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db.engine.server_version + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: diff --git a/synapse/config/_util.py b/synapse/config/_util.py new file mode 100644 index 000000000000..cd31b1c3c9d0 --- /dev/null +++ b/synapse/config/_util.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List + +import jsonschema + +from synapse.config._base import ConfigError +from synapse.types import JsonDict + + +def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None: + """Validates a config setting against a JsonSchema definition + + This can be used to validate a section of the config file against a schema + definition. If the validation fails, a ConfigError is raised with a textual + description of the problem. + + Args: + json_schema: the schema to validate against + config: the configuration value to be validated + config_path: the path within the config file. This will be used as a basis + for the error message. + """ + try: + jsonschema.validate(config, json_schema) + except jsonschema.ValidationError as e: + # copy `config_path` before modifying it. + path = list(config_path) + for p in list(e.path): + if isinstance(p, int): + path.append("" % p) + else: + path.append(str(p)) + + raise ConfigError( + "Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) + ) diff --git a/synapse/config/database.py b/synapse/config/database.py index 62bccd9ef52f..8a18a9ca2a7b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -100,7 +100,10 @@ def __init__(self, name: str, db_config: dict): self.name = name self.config = db_config - self.data_stores = data_stores + + # The `data_stores` config is actually talking about `databases` (we + # changed the name). + self.databases = data_stores class DatabaseConfig(Config): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dd775a97e884..c96e6ef62ac2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -55,24 +55,33 @@ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: file: - class: logging.handlers.RotatingFileHandler + class: logging.handlers.TimedRotatingFileHandler formatter: precise filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] + when: midnight + backupCount: 3 # Does not include the current log file. encoding: utf8 + + # Default to buffering writes to log file for efficiency. This means that + # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR + # logs will still be flushed immediately. + buffer: + class: logging.handlers.MemoryHandler + target: file + # The capacity is the number of log lines that are buffered before + # being written to disk. Increasing this will lead to better + # performance, at the expensive of it taking longer for log lines to + # be written to disk. + capacity: 10 + flushLevel: 30 # Flush for WARNING logs as well + + # A handler that writes logs to stderr. Unused by default, but can be used + # instead of "buffer" and "file" in the logger handlers. console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: @@ -80,9 +89,24 @@ # information such as access tokens. level: INFO + twisted: + # We send the twisted logging directly to the file handler, + # to work around https://github.com/matrix-org/synapse/issues/3471 + # when using "buffer" logger. Use "console" to log to stderr instead. + handlers: [file] + propagate: false + root: level: INFO - handlers: [file, console] + + # Write logs to the `buffer` handler, which will buffer them together in memory, + # then write them to a file. + # + # Replace "buffer" with "console" to log to stderr instead. (Note that you'll + # also need to update the configuation for the `twisted` logger above, in + # this case.) + # + handlers: [buffer] disable_existing_loggers: false """ @@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): handler = logging.StreamHandler() handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) else: logging.config.dictConfig(log_config) + # We add a log record factory that runs all messages through the + # LoggingContextFilter so that we get the context *at the time we log* + # rather than when we write to a handler. This can be done in config using + # filter options, but care must when using e.g. MemoryHandler to buffer + # writes. + + log_filter = LoggingContextFilter(request="") + old_factory = logging.getLogRecordFactory() + + def factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + log_filter.filter(record) + return record + + logging.setLogRecordFactory(factory) + # Route Twisted's native logging through to the standard library logging # system. observer = STDLibLogObserver() diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 147e6417ded4..036f8c0e9090 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -15,11 +15,15 @@ # limitations under the License. import logging +from typing import Any, List + +import attr from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +from ._util import validate_config logger = logging.getLogger(__name__) @@ -77,6 +81,11 @@ def read_config(self, config, **kwargs): self.saml2_enabled = True + attribute_requirements = saml2_config.get("attribute_requirements") or [] + self.attribute_requirements = _parse_attribute_requirements_def( + attribute_requirements + ) + self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) @@ -332,6 +341,17 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #grandfathered_mxid_source_attribute: upn + # It is possible to configure Synapse to only allow logins if SAML attributes + # match particular values. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. + # + #attribute_requirements: + # - attribute: userGroup + # value: "staff" + # - attribute: department + # value: "sales" + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -359,3 +379,34 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): """ % { "config_dir_path": config_dir_path } + + +@attr.s(frozen=True) +class SamlAttributeRequirement: + """Object describing a single requirement for SAML attributes.""" + + attribute = attr.ib(type=str) + value = attr.ib(type=str) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + +ATTRIBUTE_REQUIREMENTS_SCHEMA = { + "type": "array", + "items": SamlAttributeRequirement.JSON_SCHEMA, +} + + +def _parse_attribute_requirements_def( + attribute_requirements: Any, +) -> List[SamlAttributeRequirement]: + validate_config( + ATTRIBUTE_REQUIREMENTS_SCHEMA, + attribute_requirements, + config_path=["saml2_config", "attribute_requirements"], + ) + return [SamlAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/server.py b/synapse/config/server.py index 848587d2323c..9f15ed109e18 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -530,6 +530,21 @@ class LimitRemoteRoomsConfig(object): "request_token_inhibit_3pid_errors", False, ) + # List of users trialing the new experimental default push rules. This setting is + # not included in the sample configuration file on purpose as it's a temporary + # hack, so that some users can trial the new defaults without impacting every + # user on the homeserver. + users_new_default_push_rules = ( + config.get("users_new_default_push_rules") or [] + ) # type: list + if not isinstance(users_new_default_push_rules, list): + raise ConfigError("'users_new_default_push_rules' must be a list") + + # Turn the list into a set to improve lookup speed. + self.users_new_default_push_rules = set( + users_new_default_push_rules + ) # type: set + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index a5a2a7815d61..777c0f00b18d 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -48,6 +48,14 @@ class ServerContextFactory(ContextFactory): connections.""" def __init__(self, config): + # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, + # switch to those (see https://github.com/pyca/cryptography/issues/5379). + # + # note that, despite the confusing name, SSLv23_METHOD does *not* enforce SSLv2 + # or v3, but is a synonym for TLS_METHOD, which allows the client and server + # to negotiate an appropriate version of TLS constrained by the version options + # set with context.set_options. + # self._context = SSL.Context(SSL.SSLv23_METHOD) self.configure_context(self._context, config) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 69b53ca2bce0..9ed24380dd26 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,6 +17,7 @@ import attr from nacl.signing import SigningKey +from synapse.api.auth import Auth from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -27,6 +28,8 @@ ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -42,45 +45,46 @@ class EventBuilder(object): Attributes: room_version: Version of the target room - room_id (str) - type (str) - sender (str) - content (dict) - unsigned (dict) - internal_metadata (_EventInternalMetadata) - - _state (StateHandler) - _auth (synapse.api.Auth) - _store (DataStore) - _clock (Clock) - _hostname (str): The hostname of the server creating the event + room_id + type + sender + content + unsigned + internal_metadata + + _state + _auth + _store + _clock + _hostname: The hostname of the server creating the event _signing_key: The signing key to use to sign the event as the server """ - _state = attr.ib() - _auth = attr.ib() - _store = attr.ib() - _clock = attr.ib() - _hostname = attr.ib() - _signing_key = attr.ib() + _state = attr.ib(type=StateHandler) + _auth = attr.ib(type=Auth) + _store = attr.ib(type=DataStore) + _clock = attr.ib(type=Clock) + _hostname = attr.ib(type=str) + _signing_key = attr.ib(type=SigningKey) room_version = attr.ib(type=RoomVersion) - room_id = attr.ib() - type = attr.ib() - sender = attr.ib() + room_id = attr.ib(type=str) + type = attr.ib(type=str) + sender = attr.ib(type=str) - content = attr.ib(default=attr.Factory(dict)) - unsigned = attr.ib(default=attr.Factory(dict)) + content = attr.ib(default=attr.Factory(dict), type=JsonDict) + unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict) # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key = attr.ib(default=None) - _redacts = attr.ib(default=None) - _origin_server_ts = attr.ib(default=None) + _state_key = attr.ib(default=None, type=Optional[str]) + _redacts = attr.ib(default=None, type=Optional[str]) + _origin_server_ts = attr.ib(default=None, type=Optional[int]) internal_metadata = attr.ib( - default=attr.Factory(lambda: _EventInternalMetadata({})) + default=attr.Factory(lambda: _EventInternalMetadata({})), + type=_EventInternalMetadata, ) @property @@ -106,7 +110,7 @@ async def build(self, prev_event_ids): state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_ids = await self._auth.compute_auth_events(self, state_ids) + auth_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index cca93e3a4665..afecafe15c3e 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.types import StateMap if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore + from synapse.storage.databases.main import DataStore @attr.s(slots=True) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index dd150f89a6f0..8cbc23d901af 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -337,6 +337,28 @@ async def _transaction_transmission_loop(self) -> None: (e.retry_last_ts + e.retry_interval) / 1000.0 ), ) + + if e.retry_interval > 60 * 60 * 1000: + # we won't retry for another hour! + # (this suggests a significant outage) + # We drop pending PDUs and EDUs because otherwise they will + # rack up indefinitely. + # Note that: + # - the EDUs that are being dropped here are those that we can + # afford to drop (specifically, only typing notifications, + # read receipts and presence updates are being dropped here) + # - Other EDUs such as to_device messages are queued with a + # different mechanism + # - this is all volatile state that would be lost if the + # federation sender restarted anyway + + # dropping read receipts is a bit sad but should be solved + # through another mechanism, because this is all volatile! + self._pending_pdus = [] + self._pending_edus = [] + self._pending_edus_keyed = {} + self._pending_presence = {} + self._pending_rrs = {} except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 8280f8b9003d..c7f6cb3d73c3 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Tuple from canonicaljson import json @@ -54,7 +54,10 @@ def __init__(self, hs: "synapse.server.HomeServer"): @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + self, + destination: str, + pending_pdus: List[Tuple[EventBase, int]], + pending_edus: List[Edu], ): # Make a transaction-sending opentracing span. This span follows on from diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index fbc56c351bce..c9044a501921 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -101,7 +101,7 @@ async def handle_event(event): async def start_scheduler(): try: - return self.scheduler.start() + return await self.scheduler.start() except Exception: logger.error("Application Services Failure") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 39f29efb9690..0d7753353ea9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -161,7 +161,7 @@ async def validate_user_via_ui_auth( request_body: Dict[str, Any], clientip: str, description: str, - ) -> dict: + ) -> Tuple[dict, str]: """ Checks that the user is who they claim to be, via a UI auth. @@ -182,9 +182,14 @@ async def validate_user_via_ui_auth( describes the operation happening on their account. Returns: - The parameters for this request (which may + A tuple of (params, session_id). + + 'params' contains the parameters for this request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + Raises: InteractiveAuthIncompleteError if the client has not yet completed any of the permitted login flows @@ -206,7 +211,7 @@ async def validate_user_via_ui_auth( flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = await self.check_auth( + result, params, session_id = await self.check_ui_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -229,7 +234,7 @@ async def validate_user_via_ui_auth( if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") - return params + return params, session_id def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -239,7 +244,7 @@ def get_enabled_auth_types(self): """ return self.checkers.keys() - async def check_auth( + async def check_ui_auth( self, flows: List[List[str]], request: SynapseRequest, @@ -362,7 +367,7 @@ async def check_auth( if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session.session_id) + session.session_id, self._auth_dict_for_flows(flows, session.session_id) ) # check auth type currently being presented @@ -409,7 +414,7 @@ async def check_auth( ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError(ret) + raise InteractiveAuthIncompleteError(session.session_id, ret) async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 71a89f09c765..1924636c4d70 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -57,13 +57,10 @@ async def get_stream( timeout=0, as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False, ): """Fetches the events stream for a given user. - - If `only_keys` is not None, events from keys will be sent down. """ if room_id: @@ -93,7 +90,6 @@ async def get_stream( auth_user, pagin_config, timeout, - only_keys=only_keys, is_guest=is_guest, explicit_room_id=room_id, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0d7d1adcea5b..593932adb788 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -71,7 +71,7 @@ ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room @@ -2064,7 +2064,7 @@ async def _prep_event( if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e451d6dc86cb..48b0fc7279be 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json, json @@ -45,7 +45,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import ( Collection, @@ -93,11 +93,11 @@ def __init__(self, hs): async def get_room_data( self, - user_id: str = None, - room_id: str = None, - event_type: Optional[str] = None, - state_key: str = "", - is_guest: bool = False, + user_id: str, + room_id: str, + event_type: str, + state_key: str, + is_guest: bool, ) -> dict: """ Get data from a room. @@ -407,7 +407,7 @@ def __init__(self, hs: "HomeServer"): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are @@ -707,7 +707,7 @@ async def deduplicate_state_event( async def create_and_send_nonmember_event( self, requester: Requester, - event_dict: EventBase, + event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: @@ -768,6 +768,15 @@ async def create_new_client_event( else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + # we now ought to have some prev_events (unless it's a create event). + # + # do a quick sanity check here, rather than waiting until we've created the + # event and then try to auth it (which fails with a somewhat confusing "No + # create event in auth events") + assert ( + builder.type == EventTypes.Create or len(prev_event_ids) > 0 + ), "Attempting to create an event with no prev_events" + event = await builder.build(prev_event_ids=prev_event_ids) context = await self.state.compute_event_context(event) if requester: @@ -882,9 +891,7 @@ async def handle_new_client_event( except Exception: # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) + await self.store.remove_push_actions_from_staging(event.event_id) raise async def _validate_canonical_alias( @@ -962,7 +969,7 @@ async def persist_and_notify_client_event( # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = set() + original_alt_aliases = [] # type: List[str] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -1010,6 +1017,10 @@ def is_inviter_member_event(e): current_state_ids = await context.get_current_state_ids() + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + state_to_include_ids = [ e_id for k, e_id in current_state_ids.items() @@ -1061,7 +1072,7 @@ def is_inviter_member_event(e): raise SynapseError(400, "Cannot redact event from a different room") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f87b823c39e6..4a107570fbfb 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import json import logging -from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode import attr @@ -38,9 +38,11 @@ from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.server import HomeServer from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SESSION_COOKIE_NAME = b"oidc_session" @@ -90,7 +92,7 @@ class OidcHandler: """Handles requests related to the OpenID Connect login flow. """ - def __init__(self, hs: HomeServer): + def __init__(self, hs: "HomeServer"): self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b3a3bb8c3fd8..5387b3724fec 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -38,7 +38,7 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler -from synapse.storage.data_stores.main import DataStore +from synapse.storage.databases.main import DataStore from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -319,7 +319,7 @@ async def _on_shutdown(self): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.store.db.is_running(): + if not self.store.db_pool.is_running(): return logger.info( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8e409f24e89e..31705cdbdb7d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -37,6 +37,10 @@ from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -207,7 +211,7 @@ async def _local_membership_update( return duplicate.event_id, stream_id stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit + requester, event, context, extra_users=[target], ratelimit=ratelimit, ) prev_state_ids = await context.get_prev_state_ids() @@ -1000,7 +1004,7 @@ async def _remote_join( check_complexity = self.hs.config.limit_remote_rooms.enabled if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: - check_complexity = not await self.hs.auth.is_server_admin(user) + check_complexity = not await self.auth.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 2d506dc1f2de..c1fcb9845472 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,15 +14,16 @@ # limitations under the License. import logging import re -from typing import Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import AuthError, SynapseError from synapse.config import ConfigError +from synapse.config.saml2_config import SamlAttributeRequirement from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -34,6 +35,9 @@ from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) @@ -49,7 +53,7 @@ class Saml2SessionData: class SamlHandler: - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -62,6 +66,7 @@ def __init__(self, hs): self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) + self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -73,7 +78,7 @@ def __init__(self, hs): self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} + self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) @@ -165,11 +170,18 @@ async def _map_saml_response_to_user( saml2.BINDING_HTTP_POST, outstanding=self._outstanding_requests_dict, ) + except saml2.response.UnsolicitedResponse as e: + # the pysaml2 library helpfully logs an ERROR here, but neglects to log + # the session ID. I don't really want to put the full text of the exception + # in the (user-visible) exception message, so let's log the exception here + # so we can track down the session IDs later. + logger.warning(str(e)) + raise SynapseError(400, "Unexpected SAML2 login.") except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) + raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed") + raise SynapseError(400, "SAML2 response was not signed.") logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -188,6 +200,9 @@ async def _map_saml_response_to_user( saml2_auth.in_response_to, None ) + for requirement in self._saml2_attribute_requirements: + _check_attribute_requirement(saml2_auth.ava, requirement) + remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -294,6 +309,21 @@ def expire_sessions(self): del self._outstanding_requests_dict[reqid] +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): + values = ava.get(req.attribute, []) + for v in values: + if v == req.value: + return + + logger.info( + "SAML2 attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + raise AuthError(403, "You are not authorized to log in here.") + + DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5a19bac92933..c42dac18f5f3 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,7 +103,6 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1887,10 +1886,6 @@ async def _generate_room_entry( if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] - - unread_count = await self.store.get_unread_message_count_for_user( - room_id, sync_config.user.to_string(), - ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1899,7 +1894,6 @@ async def _generate_room_entry( account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, - unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/http/client.py b/synapse/http/client.py index 529532a0638e..8aeb70cdecc1 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -297,7 +297,7 @@ async def request(self, method, uri, data=None, headers=None): outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) - logger.info("Sending request %s %s", method, redact_uri(uri)) + logger.debug("Sending request %s %s", method, redact_uri(uri)) with start_active_span( "outgoing-client-request", diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 0c0264801504..369bf9c2fc37 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -247,7 +247,7 @@ async def _do_connect(self, protocol_factory): port = server.port try: - logger.info("Connecting to %s:%i", host.decode("ascii"), port) + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2a6373937a6e..738be43f4602 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -29,10 +29,11 @@ from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver +from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse import synapse.metrics import synapse.util.retryutils @@ -74,7 +75,7 @@ _next_id = 1 -@attr.s +@attr.s(frozen=True) class MatrixFederationRequest(object): method = attr.ib() """HTTP method @@ -110,26 +111,52 @@ class MatrixFederationRequest(object): :type: str|None """ + uri = attr.ib(init=False, type=bytes) + """The URI of this request + """ + def __attrs_post_init__(self): global _next_id - self.txn_id = "%s-O-%s" % (self.method, _next_id) + txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) + object.__setattr__(self, "txn_id", txn_id) + + destination_bytes = self.destination.encode("ascii") + path_bytes = self.path.encode("ascii") + if self.query: + query_bytes = encode_query_args(self.query) + else: + query_bytes = b"" + + # The object is frozen so we can pre-compute this. + uri = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + object.__setattr__(self, "uri", uri) + def get_json(self): if self.json_callback: return self.json_callback() return self.json -async def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response( + reactor: IReactorTime, + timeout_sec: float, + request: MatrixFederationRequest, + response: IResponse, + start_ms: int, +): """ Reads the JSON body of a response, with a timeout Args: - reactor (IReactor): twisted reactor, for the timeout - timeout_sec (float): number of seconds to wait for response to complete - request (MatrixFederationRequest): the request that triggered the response - response (IResponse): response to the request + reactor: twisted reactor, for the timeout + timeout_sec: number of seconds to wait for response to complete + request: the request that triggered the response + response: response to the request + start_ms: Timestamp when request was made Returns: dict: parsed JSON response @@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response): body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( - "{%s} [%s] Timed out reading response", request.txn_id, request.destination, + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), ) raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( - "{%s} [%s] Error reading response: %s", + "{%s} [%s] Error reading response %s %s: %s", request.txn_id, request.destination, + request.method, + request.uri.decode("ascii"), e, ) raise + + time_taken_secs = reactor.seconds() - start_ms / 1000 + logger.info( - "{%s} [%s] Completed: %d %s", + "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), + time_taken_secs, + request.method, + request.uri.decode("ascii"), ) return body @@ -261,7 +300,9 @@ async def _send_request_with_optional_trailing_slash( # 'M_UNRECOGNIZED' which some endpoints can return when omitting a # trailing slash on Synapse <= v0.99.3. logger.info("Retrying request with trailing slash") - request.path += "/" + + # Request is frozen so we create a new instance + request = attr.evolve(request, path=request.path + "/") response = await self._send_request(request, **send_request_args) @@ -373,9 +414,7 @@ async def _send_request( else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse( - (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") - ) + url_bytes = request.uri url_str = url_bytes.decode("ascii") url_to_sign_bytes = urllib.parse.urlunparse( @@ -402,7 +441,7 @@ async def _send_request( headers_dict[b"Authorization"] = auth_headers - logger.info( + logger.debug( "{%s} [%s] Sending request: %s %s; timeout %fs", request.txn_id, request.destination, @@ -436,7 +475,6 @@ async def _send_request( except DNSLookupError as e: raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: - logger.info("Failed to send request: %s", e) raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( @@ -496,7 +534,7 @@ async def _send_request( break except RequestSendFailed as e: - logger.warning( + logger.info( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -654,6 +692,8 @@ async def put_json( json=data, ) + start_ms = self.clock.time_msec() + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, @@ -664,7 +704,7 @@ async def put_json( ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -720,6 +760,8 @@ async def post_json( method="POST", destination=destination, path=path, query=args, json=data ) + start_ms = self.clock.time_msec() + response = await self._send_request( request, long_retries=long_retries, @@ -733,7 +775,7 @@ async def post_json( _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response + self.reactor, _sec_timeout, request, response, start_ms, ) return body @@ -786,6 +828,8 @@ async def get_json( method="GET", destination=destination, path=path, query=args ) + start_ms = self.clock.time_msec() + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, @@ -796,7 +840,7 @@ async def get_json( ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -846,6 +890,8 @@ async def delete_json( method="DELETE", destination=destination, path=path, query=args ) + start_ms = self.clock.time_msec() + response = await self._send_request( request, long_retries=long_retries, @@ -854,7 +900,7 @@ async def delete_json( ) body = await _handle_json_response( - self.reactor, self.default_timeout, request, response + self.reactor, self.default_timeout, request, response, start_ms ) return body @@ -914,12 +960,14 @@ async def get_file( ) raise logger.info( - "{%s} [%s] Completed: %d %s [%d bytes]", + "{%s} [%s] Completed: %d %s [%d bytes] %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), length, + request.method, + request.uri.decode("ascii"), ) return (length, headers) diff --git a/synapse/http/server.py b/synapse/http/server.py index 94ab29974aa0..ffe6cfa09ee1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,7 +25,7 @@ from typing import Any, Callable, Dict, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json +from canonicaljson import encode_canonical_json, encode_pretty_printed_json from twisted.internet import defer from twisted.python import failure @@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet +from synapse.util import json_encoder from synapse.util.caches import intern_dict logger = logging.getLogger(__name__) @@ -538,7 +539,7 @@ def respond_with_json( # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object).encode("utf-8") + json_bytes = json_encoder.encode(json_object).encode("utf-8") return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) diff --git a/synapse/http/site.py b/synapse/http/site.py index 6f3b2258cc30..6e79b4782801 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -146,10 +146,9 @@ def processing(self): Returns a context manager; the correct way to use this is: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): with request.processing("FooServlet"): - yield really_handle_the_request() + await really_handle_the_request() Once the context manager is closed, the completion of the request will be logged, and the various metrics will be updated. @@ -287,7 +286,9 @@ def _finished_processing(self): # the connection dropped) code += "!" - self.site.access_logger.info( + log_level = logging.INFO if self._should_log_request() else logging.DEBUG + self.site.access_logger.log( + log_level, "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', @@ -315,6 +316,17 @@ def _finished_processing(self): except Exception as e: logger.warning("Failed to stop metrics: %r", e) + def _should_log_request(self) -> bool: + """Whether we should log at INFO that we processed the request. + """ + if self.path == b"/health": + return False + + if self.method == b"OPTIONS": + return False + + return True + class XForwardedForRequest(SynapseRequest): def __init__(self, *args, **kw): diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index a9269196b3ed..f766d16db601 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import threading -from asyncio import iscoroutine from functools import wraps from typing import TYPE_CHECKING, Dict, Optional, Set from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from twisted.python.failure import Failure from synapse.logging.context import LoggingContext, PreserveLoggingContext @@ -167,7 +166,7 @@ def update_metrics(self): ) -def run_as_background_process(desc, func, *args, **kwargs): +def run_as_background_process(desc: str, func, *args, **kwargs): """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -179,7 +178,7 @@ def run_as_background_process(desc, func, *args, **kwargs): normal synapse inlineCallbacks function). Args: - desc (str): a description for this background process type + desc: a description for this background process type func: a function, which may return a Deferred or a coroutine args: positional args for func kwargs: keyword args for func @@ -188,8 +187,7 @@ def run_as_background_process(desc, func, *args, **kwargs): follow the synapse logcontext rules. """ - @defer.inlineCallbacks - def run(): + async def run(): with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -203,29 +201,21 @@ def run(): try: result = func(*args, **kwargs) - # We probably don't have an ensureDeferred in our call stack to handle - # coroutine results, so we need to ensureDeferred here. - # - # But we need this check because ensureDeferred doesn't like being - # called on immediate values (as opposed to Deferreds or coroutines). - if iscoroutine(result): - result = defer.ensureDeferred(result) + if inspect.isawaitable(result): + result = await result - return (yield result) + return result except Exception: - # failure.Failure() fishes the original Failure out of our stack, and - # thus gives us a sensible stack trace. - f = Failure() - logger.error( - "Background process '%s' threw an exception", - desc, - exc_info=(f.type, f.value, f.getTracebackObject()), + logger.exception( + "Background process '%s' threw an exception", desc, ) finally: _background_process_in_flight_count.labels(desc).dec() with PreserveLoggingContext(): - return run() + # Note that we return a Deferred here so that it can be used in a + # looping_call and other places that expect a Deferred. + return defer.ensureDeferred(run()) def wrap_as_background_process(desc): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a7849cefa5cd..c2fb757d9a49 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,12 +194,16 @@ def invalidate_access_token(self, access_token): synapse.api.errors.AuthError: the access token is invalid """ # see if the access token corresponds to a device - user_info = yield self._auth.get_user_by_access_token(access_token) + user_info = yield defer.ensureDeferred( + self._auth.get_user_by_access_token(access_token) + ) device_id = user_info.get("device_id") user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self._hs.get_device_handler().delete_device(user_id, device_id) + yield defer.ensureDeferred( + self._hs.get_device_handler().delete_device(user_id, device_id) + ) else: # no associated device. Just delete the access token. yield defer.ensureDeferred( @@ -219,7 +223,7 @@ def run_db_interaction(self, desc, func, *args, **kwargs): Returns: Deferred[object]: result of func """ - return self._store.db.runInteraction(desc, func, *args, **kwargs) + return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str diff --git a/synapse/notifier.py b/synapse/notifier.py index 22ab4a9da525..dfb096e589ad 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,18 @@ import logging from collections import namedtuple -from typing import Callable, Iterable, List, TypeVar +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter @@ -24,12 +35,14 @@ import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import StreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Collection, StreamToken, UserID from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -77,7 +90,13 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user_id, rooms, current_token, time_now_ms): + def __init__( + self, + user_id: str, + rooms: Collection[str], + current_token: StreamToken, + time_now_ms: int, + ): self.user_id = user_id self.rooms = set(rooms) self.current_token = current_token @@ -93,13 +112,13 @@ def __init__(self, user_id, rooms, current_token, time_now_ms): with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key, stream_id, time_now_ms): + def notify(self, stream_key: str, stream_id: int, time_now_ms: int): """Notify any listeners for this user of a new event from an event source. Args: - stream_key(str): The stream the event came from. - stream_id(str): The new id for the stream the event came from. - time_now_ms(int): The current time in milliseconds. + stream_key: The stream the event came from. + stream_id: The new id for the stream the event came from. + time_now_ms: The current time in milliseconds. """ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token @@ -112,7 +131,7 @@ def notify(self, stream_key, stream_id, time_now_ms): self.notify_deferred = ObservableDeferred(defer.Deferred()) noify_deferred.callback(self.current_token) - def remove(self, notifier): + def remove(self, notifier: "Notifier"): """ Remove this listener from all the indexes in the Notifier it knows about. """ @@ -123,10 +142,10 @@ def remove(self, notifier): notifier.user_to_user_stream.pop(self.user_id) - def count_listeners(self): + def count_listeners(self) -> int: return len(self.notify_deferred.observers()) - def new_listener(self, token): + def new_listener(self, token: StreamToken) -> _NotificationListener: """Returns a deferred that is resolved when there is a new token greater than the given token. @@ -159,14 +178,16 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} - self.room_to_user_streams = {} + self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] + self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] + self.pending_new_room_events = ( + [] + ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -178,10 +199,9 @@ def __init__(self, hs: "synapse.server.HomeServer"): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + self.federation_sender = None if hs.should_send_federation(): self.federation_sender = hs.get_federation_sender() - else: - self.federation_sender = None self.state_handler = hs.get_state_handler() @@ -193,12 +213,12 @@ def __init__(self, hs: "synapse.server.HomeServer"): # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_user_streams = set() + all_user_streams = set() # type: Set[_NotifierUserStream] - for x in list(self.room_to_user_streams.values()): - all_user_streams |= x - for x in list(self.user_to_user_stream.values()): - all_user_streams.add(x) + for streams in list(self.room_to_user_streams.values()): + all_user_streams |= streams + for stream in list(self.user_to_user_stream.values()): + all_user_streams.add(stream) return sum(stream.count_listeners() for stream in all_user_streams) @@ -223,7 +243,11 @@ def add_replication_callback(self, cb: Callable[[], None]): self.replication_callbacks.append(cb) def on_new_room_event( - self, event, room_stream_id, max_room_stream_id, extra_users=[] + self, + event: EventBase, + room_stream_id: int, + max_room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -241,11 +265,11 @@ def on_new_room_event( self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id): + def _notify_pending_new_room_events(self, max_room_stream_id: int): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id(int): The highest stream_id below which all + max_room_stream_id: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -258,7 +282,12 @@ def _notify_pending_new_room_events(self, max_room_stream_id): else: self._on_new_room_event(event, room_stream_id, extra_users) - def _on_new_room_event(self, event, room_stream_id, extra_users=[]): + def _on_new_room_event( + self, + event: EventBase, + room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], + ): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( @@ -275,13 +304,19 @@ def _on_new_room_event(self, event, room_stream_id, extra_users=[]): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - async def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id: int): try: await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") - def on_new_event(self, stream_key, new_token, users=[], rooms=[]): + def on_new_event( + self, + stream_key: str, + new_token: int, + users: Collection[Union[str, UserID]] = [], + rooms: Collection[str] = [], + ): """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. @@ -307,14 +342,19 @@ def on_new_event(self, stream_key, new_token, users=[], rooms=[]): self.notify_replication() - def on_new_replication_data(self): + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" self.notify_replication() async def wait_for_events( - self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START - ): + self, + user_id: str, + timeout: int, + callback: Callable[[StreamToken, StreamToken], Awaitable[T]], + room_ids=None, + from_token=StreamToken.START, + ) -> T: """Wait until the callback returns a non empty response or the timeout fires. """ @@ -377,19 +417,16 @@ async def wait_for_events( async def get_events_for( self, - user, - pagination_config, - timeout, - only_keys=None, - is_guest=False, - explicit_room_id=None, - ): + user: UserID, + pagination_config: PaginationConfig, + timeout: int, + is_guest: bool = False, + explicit_room_id: str = None, + ) -> EventStreamResult: """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. - If `only_keys` is not None, events from keys will be sent down. - If explicit_room_id is not set, the user's joined rooms will be polled for events. If explicit_room_id is set, that room will be polled for events only if @@ -404,11 +441,13 @@ async def get_events_for( room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - async def check_for_updates(before_token, after_token): + async def check_for_updates( + before_token: StreamToken, after_token: StreamToken + ) -> EventStreamResult: if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) - events = [] + events = [] # type: List[EventBase] end_token = from_token for name, source in self.event_sources.sources.items(): @@ -417,8 +456,6 @@ async def check_for_updates(before_token, after_token): after_id = getattr(after_token, keyname) if before_id == after_id: continue - if only_keys and name not in only_keys: - continue new_events, new_key = await source.get_new_events( user=user, @@ -476,7 +513,9 @@ async def check_for_updates(before_token, after_token): return result - async def _get_room_ids(self, user, explicit_room_id): + async def _get_room_ids( + self, user: UserID, explicit_room_id: Optional[str] + ) -> Tuple[Collection[str], bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: @@ -486,7 +525,7 @@ async def _get_room_ids(self, user, explicit_room_id): raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - async def _is_world_readable(self, room_id): + async def _is_world_readable(self, room_id: str) -> bool: state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) @@ -496,7 +535,7 @@ async def _is_world_readable(self, room_id): return False @log_function - def remove_expired_streams(self): + def remove_expired_streams(self) -> None: time_now_ms = self.clock.time_msec() expired_streams = [] expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS @@ -510,21 +549,21 @@ def remove_expired_streams(self): expired_stream.remove(self) @log_function - def _register_with_keys(self, user_stream): + def _register_with_keys(self, user_stream: _NotifierUserStream): self.user_to_user_stream[user_stream.user_id] = user_stream for room in user_stream.rooms: s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - def _user_joined_room(self, user_id, room_id): + def _user_joined_room(self, user_id: str, room_id: str): new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) - def notify_replication(self): + def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 286374d0b537..8047873ff1d9 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,11 +19,13 @@ from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP -def list_with_base_rules(rawrules): +def list_with_base_rules(rawrules, use_new_defaults=False): """Combine the list of rules set by the user with the default push rules Args: rawrules(list): The rules the user has modified or set. + use_new_defaults(bool): Whether to use the new experimental default rules when + appending or prepending default rules. Returns: A new list with the rules set by the user combined with the defaults. @@ -43,7 +45,9 @@ def list_with_base_rules(rawrules): ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) @@ -54,6 +58,7 @@ def list_with_base_rules(rawrules): make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 @@ -62,6 +67,7 @@ def list_with_base_rules(rawrules): make_base_prepend_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) @@ -70,27 +76,39 @@ def list_with_base_rules(rawrules): while current_prio_class > 0: ruleslist.extend( make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) return ruleslist -def make_base_append_rules(kind, modified_base_rules): +def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = BASE_APPEND_OVERRIDE_RULES + rules = ( + NEW_APPEND_OVERRIDE_RULES + if use_new_defaults + else BASE_APPEND_OVERRIDE_RULES + ) elif kind == "underride": - rules = BASE_APPEND_UNDERRIDE_RULES + rules = ( + NEW_APPEND_UNDERRIDE_RULES + if use_new_defaults + else BASE_APPEND_UNDERRIDE_RULES + ) elif kind == "content": rules = BASE_APPEND_CONTENT_RULES @@ -105,7 +123,7 @@ def make_base_append_rules(kind, modified_base_rules): return rules -def make_base_prepend_rules(kind, modified_base_rules): +def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": @@ -270,6 +288,135 @@ def make_base_prepend_rules(kind, modified_base_rules): ] +NEW_APPEND_OVERRIDE_RULES = [ + { + "rule_id": "global/override/.m.rule.encrypted", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", + } + ], + "actions": ["notify"], + }, + { + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_suppress_notices_type", + }, + { + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", + }, + ], + "actions": [], + }, + { + "rule_id": "global/underride/.m.rule.suppress_edits", + "conditions": [ + { + "kind": "event_match", + "key": "m.relates_to.m.rel_type", + "pattern": "m.replace", + "_id": "_suppress_edits", + } + ], + "actions": [], + }, + { + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", + }, + { + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", + }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], + }, + { + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", + }, + { + "kind": "event_match", + "key": "state_key", + "pattern": "", + "_id": "_tombstone_statekey", + }, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ + { + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", + }, + { + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", + }, + ], + "actions": [ + "notify", + {"set_tweak": "highlight"}, + {"set_tweak": "sound", "value": "default"}, + ], + }, + { + "rule_id": "global/override/.m.rule.call", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", + } + ], + "actions": ["notify", {"set_tweak": "sound", "value": "ring"}], + }, +] + + BASE_APPEND_UNDERRIDE_RULES = [ { "rule_id": "global/underride/.m.rule.call", @@ -354,6 +501,36 @@ def make_base_prepend_rules(kind, modified_base_rules): ] +NEW_APPEND_UNDERRIDE_RULES = [ + { + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, + ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], + }, + { + "rule_id": "global/underride/.m.rule.message", + "conditions": [ + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, + ], + "actions": ["notify"], + "enabled": False, + }, +] + + BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: @@ -375,3 +552,26 @@ def make_base_prepend_rules(kind, modified_base_rules): r["priority_class"] = PRIORITY_CLASS_MAP["underride"] r["default"] = True BASE_RULE_IDS.add(r["rule_id"]) + + +NEW_RULE_IDS = set() + +for r in BASE_APPEND_CONTENT_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in BASE_PREPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_UNDERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 04b9d8ac8265..e7fcee0e8701 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -120,7 +120,7 @@ async def _get_power_levels_and_sender_level(self, event, context): pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index bc8f71916b16..d0145666bfd9 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,13 +21,22 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + badge = len(invites) for room_id in joins: - unread_count = await store.get_unread_message_count_for_user(room_id, user_id) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if unread_count else 0 + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] + + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f9e2533e9639..60f2e1245f99 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -16,8 +16,8 @@ import logging from typing import Optional -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -25,7 +25,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 525b94fd87bc..154f0e687c58 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -17,13 +17,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.storage.data_stores.main.tags import TagsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data", diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index a67fbeffb779..0f8d7037bde1 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.databases.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1a38f53dfb8b..a6fdedde6357 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index a8a16dbc711c..ee7f69a91816 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ToDeviceStream -from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_inbox", "stream_id" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 9d8067342fd2..722f3745e9bc 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.data_stores.main.devices import DeviceWorkerStore -from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 8b9717c46fb8..1945bcf9a8d8 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.directory import DirectoryWorkerStore +from synapse.storage.databases.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 1a1a50a24f07..da1cc836cf70 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,18 +15,18 @@ # limitations under the License. import logging -from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.relations import RelationsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.storage.data_stores.main.stream import StreamWorkerStore -from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore -from synapse.storage.database import Database +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -55,11 +55,11 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index bcb068895496..2562b6fc383f 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.filtering import FilteringStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 5d210fa3a1d9..3291558c7a76 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -16,13 +16,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream -from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 3def367ae979..961579751cdf 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.keys import KeyStore +from synapse.storage.databases.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 2938cb8e4326..a912c04360e1 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,8 +15,8 @@ from synapse.replication.tcp.streams import PresenceStream from synapse.storage import DataStore -from synapse.storage.data_stores.main.presence import PresenceStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -24,7 +24,7 @@ class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 28c508aad345..f85b20a07177 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.data_stores.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 23ec1c5b112c..590187df4653 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.replication.tcp.streams import PushRulesStream -from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from .events import SlavedEventStore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index ff449f36589b..63300e5da608 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,15 +15,15 @@ # limitations under the License. from synapse.replication.tcp.streams import PushersStream -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6982686eb512..17ba1f22ac47 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,15 +15,15 @@ # limitations under the License. from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 4b8553e25030..a40f064e2b63 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 8710207ada0b..427c81772b51 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,15 +14,15 @@ # limitations under the License. from synapse.replication.tcp.streams import PublicRoomsStream -from synapse.storage.data_stores.main.room import RoomWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index ac88e6b8c35b..2091ac0df67d 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.transactions import TransactionStore +from synapse.storage.databases.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f33801f8838e..d853e4447eb6 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,11 +18,12 @@ allowed to be sent by which side. """ import abc -import json import logging from typing import Tuple, Type -_json_encoder = json.JSONEncoder() +from canonicaljson import json + +from synapse.util import json_encoder as _json_encoder logger = logging.getLogger(__name__) diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html index bfd6449c5d5e..01cd9bdaf3c5 100644 --- a/synapse/res/templates/saml_error.html +++ b/synapse/res/templates/saml_error.html @@ -2,10 +2,17 @@ - SSO error + SSO login error -

Oops! Something went wrong during authentication.

+{# a 403 means we have actively rejected their login #} +{% if code == 403 %} +

You are not allowed to log in here.

+{% else %} +

+ There was an error during authentication: +

+
{{ msg }}

If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the @@ -37,9 +44,9 @@ // to print one. let errorDesc = new URLSearchParams(searchStr).get("error_description") if (errorDesc) { - - document.getElementById("errormsg").innerText = ` ("${errorDesc}")`; + document.getElementById("errormsg").innerText = errorDesc; } +{% endif %} - \ No newline at end of file + diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a8364d9793d7..7c292ef3f939 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -31,7 +31,7 @@ assert_user_is_admin, historical_admin_path_patterns, ) -from synapse.storage.data_stores.main.room import RoomSortOrder +from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 5934b1fe8bdc..b210015173b7 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -89,7 +89,7 @@ async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9fd490813693..00831879f387 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -25,7 +25,7 @@ parse_json_value_from_request, parse_string, ) -from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client.v2_alpha._base import client_patterns @@ -45,6 +45,8 @@ def __init__(self, hs): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -179,7 +181,12 @@ def set_rule_attr(self, user_id, spec, val): rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: + if user_id in self._users_new_default_push_rules: + rule_ids = NEW_RULE_IDS + else: + rule_ids = BASE_RULE_IDS + + if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) return self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 13c713e66509..d4006f2c1789 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,7 +18,12 @@ from http import HTTPStatus from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( @@ -240,18 +245,12 @@ async def on_POST(self, request): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. - if "new_password" in body: - new_password = body.pop("new_password") + new_password = body.pop("new_password", None) + if new_password is not None: if not isinstance(new_password, str) or len(new_password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "new_password_hash" in body: - raise SynapseError(400, "Unexpected property: new_password_hash") - body["new_password_hash"] = await self.auth_handler.hash(new_password) - # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -264,23 +263,49 @@ async def on_POST(self, request): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - params = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None - result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] @@ -305,12 +330,21 @@ async def on_POST(self, request): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password_hash"]) - new_password_hash = params["new_password_hash"] + # If we have a password in this request, prefer it. Otherwise, there + # must be a password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + else: + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password_hash, logout_devices, requester + user_id, password_hash, logout_devices, requester ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6bdacacb8124..7794c8a5be33 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -24,6 +24,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, + InteractiveAuthIncompleteError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -386,6 +387,7 @@ def __init__(self, hs): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -411,20 +413,8 @@ async def on_POST(self, request): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the username/password provided to us. - if "password" in body: - password = body.pop("password") - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "password_hash" in body: - raise SynapseError(400, "Unexpected property: password_hash") - body["password_hash"] = await self.auth_handler.hash(password) - + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. desired_username = None if "username" in body: if not isinstance(body["username"], str) or len(body["username"]) > 512: @@ -433,7 +423,7 @@ async def on_POST(self, request): appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -458,22 +448,35 @@ async def on_POST(self, request): ) return 200, result # we throw for non 200 responses - # for regular registration, downcase the provided username before - # attempting to register it. This should mean - # that people who try to register with upper-case in their usernames - # don't get a nasty surprise. (Note that we treat username - # case-insenstively in login, so they are free to carry on imagining - # that their username is CrAzYh4cKeR if that keeps them happy) - if desired_username is not None: - desired_username = desired_username.lower() - # == Normal User Registration == (everyone else) - if not self.hs.config.enable_registration: + if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled") + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. + if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password_hash" not in body: + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -483,6 +486,7 @@ async def on_POST(self, request): session_id = self.auth_handler.get_session_id(body) registered_user_id = None + password_hash = None if session_id: # if we get a registered user id out of here, it means we previously # registered a user for this session, so we could just return the @@ -491,7 +495,12 @@ async def on_POST(self, request): registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + # Ensure that the username is valid. if desired_username is not None: await self.registration_handler.check_username( desired_username, @@ -499,20 +508,38 @@ async def on_POST(self, request): assigned_user_id=registered_user_id, ) - auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", - ) + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise # Check that we're not trying to register a denied 3pid. # # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. - if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: @@ -534,12 +561,15 @@ async def on_POST(self, request): # don't re-register the threepids registered = False else: - # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password_hash"]) + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -581,7 +611,7 @@ async def on_POST(self, request): registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password_hash=new_password_hash, + password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -594,8 +624,8 @@ async def on_POST(self, request): ): await self.store.upsert_monthly_active_user(registered_user_id) - # remember that we've now registered that user account, and with - # what user ID (since the user may not have specified) + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) @@ -634,7 +664,7 @@ async def _create_registration_details(self, user_id, params): (object) params: registration parameters, from which we pull device_id, initial_device_name and inhibit_login Returns: - defer.Deferred: (object) dictionary for response from /register + (object) dictionary for response from /register """ result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 3f5bf75e592e..a5c24fbd63da 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,7 +426,6 @@ def serialize(events): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 4386eb4e72ba..b3e4d5612ed2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -22,8 +22,6 @@ import jinja2 from jinja2 import TemplateNotFound -from twisted.internet import defer - from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError from synapse.http.server import DirectServeHtmlResource, respond_with_html @@ -135,7 +133,7 @@ async def _async_render_GET(self, request): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id) + u = await self.store.get_user_by_id(qualified_user_id) if u is None: raise NotFoundError("Unknown user") diff --git a/synapse/rest/health.py b/synapse/rest/health.py new file mode 100644 index 000000000000..0170950bf382 --- /dev/null +++ b/synapse/rest/health.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + + +class HealthResource(Resource): + """A resource that does nothing except return a 200 with a body of `OK`, + which can be used as a health check. + + Note: `SynapseRequest._should_log_request` ensures that requests to + `/health` do not get logged at INFO. + """ + + isLeaf = 1 + + def render_GET(self, request): + request.setHeader(b"Content-Type", b"text/plain") + return b"OK" diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e12f65a20649..cd8c246594cf 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -27,9 +27,7 @@ from urllib import parse as urlparse import attr -from canonicaljson import json -from twisted.internet import defer from twisted.internet.error import DNSLookupError from synapse.api.errors import Codes, SynapseError @@ -43,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string @@ -228,7 +227,7 @@ async def _async_render_GET(self, request): else: logger.info("Returning cached response") - og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) + og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) async def _do_preview(self, url: str, user: str, ts: int) -> bytes: @@ -355,7 +354,7 @@ async def _do_preview(self, url: str, user: str, ts: int) -> bytes: logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og) + jsonog = json_encoder.encode(og) # store OG in history-aware DB cache await self.store.store_url_cache( @@ -586,7 +585,7 @@ async def _expire_url_cache_data(self): logger.debug("Running url preview cache expiry") - if not (await self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db_pool.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/secrets.py b/synapse/secrets.py index 5f43f81eb0fd..ff86950a5472 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -25,8 +25,12 @@ if sys.version_info[0:2] >= (3, 6): import secrets - def Secrets(): - return secrets + class Secrets: + def token_bytes(self, nbytes=32): + return secrets.token_bytes(nbytes) + + def token_hex(self, nbytes=32): + return secrets.token_hex(nbytes) else: diff --git a/synapse/server.py b/synapse/server.py index 8e4111253008..9055b97ac317 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -22,10 +22,14 @@ # Imports required for the default HomeServer() implementation import abc +import functools import logging import os +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +import twisted from twisted.mail.smtp import sendmail +from twisted.web.iweb import IPolicyForHTTPS from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -93,7 +97,7 @@ from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -105,32 +109,74 @@ WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStore, DataStores, Storage +from synapse.storage import Databases, DataStore, Storage from synapse.streams.events import EventSources +from synapse.types import DomainSpecificString from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.handlers.oidc_handler import OidcHandler + from synapse.handlers.saml_handler import SamlHandler -class HomeServer(object): + +T = TypeVar("T", bound=Callable[..., Any]) + + +def cache_in_self(builder: T) -> T: + """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and + returning if so. If not, calls the given function and sets `self.foo` to it. + + Also ensures that dependency cycles throw an exception correctly, rather + than overflowing the stack. + """ + + if not builder.__name__.startswith("get_"): + raise Exception( + "@cache_in_self can only be used on functions starting with `get_`" + ) + + depname = builder.__name__[len("get_") :] + + building = [False] + + @functools.wraps(builder) + def _get(self): + try: + return getattr(self, depname) + except AttributeError: + pass + + # Prevent cyclic dependencies from deadlocking + if building[0]: + raise ValueError("Cyclic dependency while building %s" % (depname,)) + + building[0] = True + try: + dep = builder(self) + setattr(self, depname, dep) + finally: + building[0] = False + + return dep + + # We cast here as we need to tell mypy that `_get` has the same signature as + # `builder`. + return cast(T, _get) + + +class HomeServer(metaclass=abc.ABCMeta): """A basic homeserver object without lazy component builders. This will need all of the components it requires to either be passed as constructor arguments, or the relevant methods overriding to create them. Typically this would only be used for unit tests. - For every dependency in the DEPENDENCIES list below, this class creates one - method, - def get_DEPENDENCY(self) - which returns the value of that dependency. If no value has yet been set - nor was provided to the constructor, it will attempt to call a lazy builder - method called - def build_DEPENDENCY(self) - which must be implemented by the subclass. This code may call any of the - required "get" methods on the instance to obtain the sub-dependencies that - one requires. + Dependencies should be added by creating a `def get_(self)` + function, wrapping it in `@cache_in_self`. Attributes: config (synapse.config.homeserver.HomeserverConfig): @@ -138,86 +184,6 @@ def build_DEPENDENCY(self) we are listening on to provide HTTP services. """ - __metaclass__ = abc.ABCMeta - - DEPENDENCIES = [ - "http_client", - "federation_client", - "federation_server", - "handlers", - "auth", - "room_creation_handler", - "room_shutdown_handler", - "state_handler", - "state_resolution_handler", - "presence_handler", - "sync_handler", - "typing_handler", - "room_list_handler", - "acme_handler", - "auth_handler", - "device_handler", - "stats_handler", - "e2e_keys_handler", - "e2e_room_keys_handler", - "event_handler", - "event_stream_handler", - "initial_sync_handler", - "application_service_api", - "application_service_scheduler", - "application_service_handler", - "device_message_handler", - "profile_handler", - "event_creation_handler", - "deactivate_account_handler", - "set_password_handler", - "notifier", - "event_sources", - "keyring", - "pusherpool", - "event_builder_factory", - "filtering", - "http_client_context_factory", - "simple_http_client", - "proxied_http_client", - "media_repository", - "media_repository_resource", - "federation_transport_client", - "federation_sender", - "receipts_handler", - "macaroon_generator", - "tcp_replication", - "read_marker_handler", - "action_generator", - "user_directory_handler", - "groups_local_handler", - "groups_server_handler", - "groups_attestation_signing", - "groups_attestation_renewer", - "secrets", - "spam_checker", - "third_party_event_rules", - "room_member_handler", - "federation_registry", - "server_notices_manager", - "server_notices_sender", - "message_handler", - "pagination_handler", - "room_context_handler", - "sendmail", - "registration_handler", - "account_validity_handler", - "cas_handler", - "saml_handler", - "oidc_handler", - "event_client_serializer", - "password_policy_handler", - "storage", - "replication_streamer", - "replication_data_handler", - "replication_streams", - ] - REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] # This is overridden in derived application classes @@ -232,16 +198,17 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar config: The full config for the homeserver. """ if not reactor: - from twisted.internet import reactor + from twisted.internet import reactor as _reactor + + reactor = _reactor self._reactor = reactor self.hostname = hostname # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._building = {} - self._listening_services = [] - self.start_time = None + self._listening_services = [] # type: List[twisted.internet.tcp.Port] + self.start_time = None # type: Optional[int] self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" @@ -255,13 +222,13 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar burst_count=config.rc_registration.burst_count, ) - self.datastores = None + self.datastores = None # type: Optional[Databases] # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) - def get_instance_id(self): + def get_instance_id(self) -> str: """A unique ID for this synapse process instance. This is used to distinguish running instances in worker-based @@ -277,13 +244,13 @@ def get_instance_name(self) -> str: """ return self._instance_name - def setup(self): + def setup(self) -> None: logger.info("Setting up.") self.start_time = int(self.get_clock().time()) - self.datastores = DataStores(self.DATASTORE_CLASS, self) + self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def setup_master(self): + def setup_master(self) -> None: """ Some handlers have side effects on instantiation (like registering background updates). This function causes them to be fetched, and @@ -292,192 +259,242 @@ def setup_master(self): for i in self.REQUIRED_ON_MASTER_STARTUP: getattr(self, "get_" + i)() - def get_reactor(self): + def get_reactor(self) -> twisted.internet.base.ReactorBase: """ Fetch the Twisted reactor in use by this HomeServer. """ return self._reactor - def get_ip_from_request(self, request): + def get_ip_from_request(self, request) -> str: # X-Forwarded-For is handled by our custom request type. return request.getClientIP() - def is_mine(self, domain_specific_string): + def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname - def is_mine_id(self, string): + def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname - def get_clock(self): + def get_clock(self) -> Clock: return self.clock def get_datastore(self) -> DataStore: + if not self.datastores: + raise Exception("HomeServer.setup must be called before getting datastores") + return self.datastores.main - def get_datastores(self): + def get_datastores(self) -> Databases: + if not self.datastores: + raise Exception("HomeServer.setup must be called before getting datastores") + return self.datastores - def get_config(self): + def get_config(self) -> HomeServerConfig: return self.config - def get_distributor(self): + def get_distributor(self) -> Distributor: return self.distributor def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def build_federation_client(self): + @cache_in_self + def get_federation_client(self) -> FederationClient: return FederationClient(self) - def build_federation_server(self): + @cache_in_self + def get_federation_server(self) -> FederationServer: return FederationServer(self) - def build_handlers(self): + @cache_in_self + def get_handlers(self) -> Handlers: return Handlers(self) - def build_notifier(self): + @cache_in_self + def get_notifier(self) -> Notifier: return Notifier(self) - def build_auth(self): + @cache_in_self + def get_auth(self) -> Auth: return Auth(self) - def build_http_client_context_factory(self): + @cache_in_self + def get_http_client_context_factory(self) -> IPolicyForHTTPS: return ( InsecureInterceptableContextFactory() if self.config.use_insecure_ssl_client_just_for_testing_do_not_use else RegularPolicyForHTTPS() ) - def build_simple_http_client(self): + @cache_in_self + def get_simple_http_client(self) -> SimpleHttpClient: return SimpleHttpClient(self) - def build_proxied_http_client(self): + @cache_in_self + def get_proxied_http_client(self) -> SimpleHttpClient: return SimpleHttpClient( self, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), ) - def build_room_creation_handler(self): + @cache_in_self + def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) - def build_room_shutdown_handler(self): + @cache_in_self + def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) - def build_sendmail(self): + @cache_in_self + def get_sendmail(self) -> sendmail: return sendmail - def build_state_handler(self): + @cache_in_self + def get_state_handler(self) -> StateHandler: return StateHandler(self) - def build_state_resolution_handler(self): + @cache_in_self + def get_state_resolution_handler(self) -> StateResolutionHandler: return StateResolutionHandler(self) - def build_presence_handler(self): + @cache_in_self + def get_presence_handler(self) -> PresenceHandler: return PresenceHandler(self) - def build_typing_handler(self): + @cache_in_self + def get_typing_handler(self): if self.config.worker.writers.typing == self.get_instance_name(): return TypingWriterHandler(self) else: return FollowerTypingHandler(self) - def build_sync_handler(self): + @cache_in_self + def get_sync_handler(self) -> SyncHandler: return SyncHandler(self) - def build_room_list_handler(self): + @cache_in_self + def get_room_list_handler(self) -> RoomListHandler: return RoomListHandler(self) - def build_auth_handler(self): + @cache_in_self + def get_auth_handler(self) -> AuthHandler: return AuthHandler(self) - def build_macaroon_generator(self): + @cache_in_self + def get_macaroon_generator(self) -> MacaroonGenerator: return MacaroonGenerator(self) - def build_device_handler(self): + @cache_in_self + def get_device_handler(self): if self.config.worker_app: return DeviceWorkerHandler(self) else: return DeviceHandler(self) - def build_device_message_handler(self): + @cache_in_self + def get_device_message_handler(self) -> DeviceMessageHandler: return DeviceMessageHandler(self) - def build_e2e_keys_handler(self): + @cache_in_self + def get_e2e_keys_handler(self) -> E2eKeysHandler: return E2eKeysHandler(self) - def build_e2e_room_keys_handler(self): + @cache_in_self + def get_e2e_room_keys_handler(self) -> E2eRoomKeysHandler: return E2eRoomKeysHandler(self) - def build_acme_handler(self): + @cache_in_self + def get_acme_handler(self) -> AcmeHandler: return AcmeHandler(self) - def build_application_service_api(self): + @cache_in_self + def get_application_service_api(self) -> ApplicationServiceApi: return ApplicationServiceApi(self) - def build_application_service_scheduler(self): + @cache_in_self + def get_application_service_scheduler(self) -> ApplicationServiceScheduler: return ApplicationServiceScheduler(self) - def build_application_service_handler(self): + @cache_in_self + def get_application_service_handler(self) -> ApplicationServicesHandler: return ApplicationServicesHandler(self) - def build_event_handler(self): + @cache_in_self + def get_event_handler(self) -> EventHandler: return EventHandler(self) - def build_event_stream_handler(self): + @cache_in_self + def get_event_stream_handler(self) -> EventStreamHandler: return EventStreamHandler(self) - def build_initial_sync_handler(self): + @cache_in_self + def get_initial_sync_handler(self) -> InitialSyncHandler: return InitialSyncHandler(self) - def build_profile_handler(self): + @cache_in_self + def get_profile_handler(self): if self.config.worker_app: return BaseProfileHandler(self) else: return MasterProfileHandler(self) - def build_event_creation_handler(self): + @cache_in_self + def get_event_creation_handler(self) -> EventCreationHandler: return EventCreationHandler(self) - def build_deactivate_account_handler(self): + @cache_in_self + def get_deactivate_account_handler(self) -> DeactivateAccountHandler: return DeactivateAccountHandler(self) - def build_set_password_handler(self): + @cache_in_self + def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) - def build_event_sources(self): + @cache_in_self + def get_event_sources(self) -> EventSources: return EventSources(self) - def build_keyring(self): + @cache_in_self + def get_keyring(self) -> Keyring: return Keyring(self) - def build_event_builder_factory(self): + @cache_in_self + def get_event_builder_factory(self) -> EventBuilderFactory: return EventBuilderFactory(self) - def build_filtering(self): + @cache_in_self + def get_filtering(self) -> Filtering: return Filtering(self) - def build_pusherpool(self): + @cache_in_self + def get_pusherpool(self) -> PusherPool: return PusherPool(self) - def build_http_client(self): + @cache_in_self + def get_http_client(self) -> MatrixFederationHttpClient: tls_client_options_factory = context_factory.FederationPolicyForHTTPS( self.config ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_media_repository_resource(self): + @cache_in_self + def get_media_repository_resource(self) -> MediaRepositoryResource: # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of return MediaRepositoryResource(self) - def build_media_repository(self): + @cache_in_self + def get_media_repository(self) -> MediaRepository: return MediaRepository(self) - def build_federation_transport_client(self): + @cache_in_self + def get_federation_transport_client(self) -> TransportLayerClient: return TransportLayerClient(self) - def build_federation_sender(self): + @cache_in_self + def get_federation_sender(self): if self.should_send_federation(): return FederationSender(self) elif not self.config.worker_app: @@ -485,156 +502,152 @@ def build_federation_sender(self): else: raise Exception("Workers cannot send federation traffic") - def build_receipts_handler(self): + @cache_in_self + def get_receipts_handler(self) -> ReceiptsHandler: return ReceiptsHandler(self) - def build_read_marker_handler(self): + @cache_in_self + def get_read_marker_handler(self) -> ReadMarkerHandler: return ReadMarkerHandler(self) - def build_tcp_replication(self): + @cache_in_self + def get_tcp_replication(self) -> ReplicationCommandHandler: return ReplicationCommandHandler(self) - def build_action_generator(self): + @cache_in_self + def get_action_generator(self) -> ActionGenerator: return ActionGenerator(self) - def build_user_directory_handler(self): + @cache_in_self + def get_user_directory_handler(self) -> UserDirectoryHandler: return UserDirectoryHandler(self) - def build_groups_local_handler(self): + @cache_in_self + def get_groups_local_handler(self): if self.config.worker_app: return GroupsLocalWorkerHandler(self) else: return GroupsLocalHandler(self) - def build_groups_server_handler(self): + @cache_in_self + def get_groups_server_handler(self): if self.config.worker_app: return GroupsServerWorkerHandler(self) else: return GroupsServerHandler(self) - def build_groups_attestation_signing(self): + @cache_in_self + def get_groups_attestation_signing(self) -> GroupAttestationSigning: return GroupAttestationSigning(self) - def build_groups_attestation_renewer(self): + @cache_in_self + def get_groups_attestation_renewer(self) -> GroupAttestionRenewer: return GroupAttestionRenewer(self) - def build_secrets(self): + @cache_in_self + def get_secrets(self) -> Secrets: return Secrets() - def build_stats_handler(self): + @cache_in_self + def get_stats_handler(self) -> StatsHandler: return StatsHandler(self) - def build_spam_checker(self): + @cache_in_self + def get_spam_checker(self): return SpamChecker(self) - def build_third_party_event_rules(self): + @cache_in_self + def get_third_party_event_rules(self) -> ThirdPartyEventRules: return ThirdPartyEventRules(self) - def build_room_member_handler(self): + @cache_in_self + def get_room_member_handler(self): if self.config.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) - def build_federation_registry(self): + @cache_in_self + def get_federation_registry(self) -> FederationHandlerRegistry: return FederationHandlerRegistry(self) - def build_server_notices_manager(self): + @cache_in_self + def get_server_notices_manager(self): if self.config.worker_app: raise Exception("Workers cannot send server notices") return ServerNoticesManager(self) - def build_server_notices_sender(self): + @cache_in_self + def get_server_notices_sender(self): if self.config.worker_app: return WorkerServerNoticesSender(self) return ServerNoticesSender(self) - def build_message_handler(self): + @cache_in_self + def get_message_handler(self) -> MessageHandler: return MessageHandler(self) - def build_pagination_handler(self): + @cache_in_self + def get_pagination_handler(self) -> PaginationHandler: return PaginationHandler(self) - def build_room_context_handler(self): + @cache_in_self + def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) - def build_registration_handler(self): + @cache_in_self + def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) - def build_account_validity_handler(self): + @cache_in_self + def get_account_validity_handler(self) -> AccountValidityHandler: return AccountValidityHandler(self) - def build_cas_handler(self): + @cache_in_self + def get_cas_handler(self) -> CasHandler: return CasHandler(self) - def build_saml_handler(self): + @cache_in_self + def get_saml_handler(self) -> "SamlHandler": from synapse.handlers.saml_handler import SamlHandler return SamlHandler(self) - def build_oidc_handler(self): + @cache_in_self + def get_oidc_handler(self) -> "OidcHandler": from synapse.handlers.oidc_handler import OidcHandler return OidcHandler(self) - def build_event_client_serializer(self): + @cache_in_self + def get_event_client_serializer(self) -> EventClientSerializer: return EventClientSerializer(self) - def build_password_policy_handler(self): + @cache_in_self + def get_password_policy_handler(self) -> PasswordPolicyHandler: return PasswordPolicyHandler(self) - def build_storage(self) -> Storage: - return Storage(self, self.datastores) + @cache_in_self + def get_storage(self) -> Storage: + return Storage(self, self.get_datastores()) - def build_replication_streamer(self) -> ReplicationStreamer: + @cache_in_self + def get_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) - def build_replication_data_handler(self): + @cache_in_self + def get_replication_data_handler(self) -> ReplicationDataHandler: return ReplicationDataHandler(self) - def build_replication_streams(self): + @cache_in_self + def get_replication_streams(self) -> Dict[str, Stream]: return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} - def remove_pusher(self, app_id, push_key, user_id): - return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): + return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) - def should_send_federation(self): + def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" return self.config.send_federation and ( not self.config.worker_app or self.config.worker_app == "synapse.app.federation_sender" ) - - -def _make_dependency_method(depname): - def _get(hs): - try: - return getattr(hs, depname) - except AttributeError: - pass - - try: - builder = getattr(hs, "build_%s" % (depname)) - except AttributeError: - raise NotImplementedError( - "%s has no %s nor a builder for it" % (type(hs).__name__, depname) - ) - - # Prevent cyclic dependencies from deadlocking - if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % (depname,)) - - hs._building[depname] = 1 - try: - dep = builder() - setattr(hs, depname, dep) - finally: - del hs._building[depname] - - return dep - - setattr(HomeServer, "get_%s" % (depname), _get) - - -# Build magic accessors for every dependency -for depname in HomeServer.DEPENDENCIES: - _make_dependency_method(depname) diff --git a/synapse/server.pyi b/synapse/server.pyi deleted file mode 100644 index 1aba408c2164..000000000000 --- a/synapse/server.pyi +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Dict - -import twisted.internet - -import synapse.api.auth -import synapse.config.homeserver -import synapse.crypto.keyring -import synapse.federation.federation_server -import synapse.federation.sender -import synapse.federation.transport.client -import synapse.handlers -import synapse.handlers.auth -import synapse.handlers.deactivate_account -import synapse.handlers.device -import synapse.handlers.e2e_keys -import synapse.handlers.message -import synapse.handlers.presence -import synapse.handlers.register -import synapse.handlers.room -import synapse.handlers.room_member -import synapse.handlers.set_password -import synapse.http.client -import synapse.http.matrixfederationclient -import synapse.notifier -import synapse.push.pusherpool -import synapse.replication.tcp.client -import synapse.replication.tcp.handler -import synapse.rest.media.v1.media_repository -import synapse.server_notices.server_notices_manager -import synapse.server_notices.server_notices_sender -import synapse.state -import synapse.storage -from synapse.events.builder import EventBuilderFactory -from synapse.handlers.typing import FollowerTypingHandler -from synapse.replication.tcp.streams import Stream - -class HomeServer(object): - @property - def config(self) -> synapse.config.homeserver.HomeServerConfig: - pass - @property - def hostname(self) -> str: - pass - def get_auth(self) -> synapse.api.auth.Auth: - pass - def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: - pass - def get_datastore(self) -> synapse.storage.DataStore: - pass - def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: - pass - def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: - pass - def get_handlers(self) -> synapse.handlers.Handlers: - pass - def get_state_handler(self) -> synapse.state.StateHandler: - pass - def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: - pass - def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: - """Fetch an HTTP client implementation which doesn't do any blacklisting - or support any HTTP_PROXY settings""" - pass - def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: - """Fetch an HTTP client implementation which doesn't do any blacklisting - but does support HTTP_PROXY settings""" - pass - def get_deactivate_account_handler( - self, - ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: - pass - def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: - pass - def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: - pass - def get_room_shutdown_handler(self) -> synapse.handlers.room.RoomShutdownHandler: - pass - def get_event_creation_handler( - self, - ) -> synapse.handlers.message.EventCreationHandler: - pass - def get_set_password_handler( - self, - ) -> synapse.handlers.set_password.SetPasswordHandler: - pass - def get_federation_sender(self) -> synapse.federation.sender.FederationSender: - pass - def get_federation_transport_client( - self, - ) -> synapse.federation.transport.client.TransportLayerClient: - pass - def get_media_repository_resource( - self, - ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: - pass - def get_media_repository( - self, - ) -> synapse.rest.media.v1.media_repository.MediaRepository: - pass - def get_server_notices_manager( - self, - ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: - pass - def get_server_notices_sender( - self, - ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: - pass - def get_notifier(self) -> synapse.notifier.Notifier: - pass - def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: - pass - def get_clock(self) -> synapse.util.Clock: - pass - def get_reactor(self) -> twisted.internet.base.ReactorBase: - pass - def get_keyring(self) -> synapse.crypto.keyring.Keyring: - pass - def get_tcp_replication( - self, - ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: - pass - def get_replication_data_handler( - self, - ) -> synapse.replication.tcp.client.ReplicationDataHandler: - pass - def get_federation_registry( - self, - ) -> synapse.federation.federation_server.FederationHandlerRegistry: - pass - def is_mine_id(self, domain_id: str) -> bool: - pass - def get_instance_id(self) -> str: - pass - def get_instance_name(self) -> str: - pass - def get_event_builder_factory(self) -> EventBuilderFactory: - pass - def get_storage(self) -> synapse.storage.Storage: - pass - def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler: - pass - def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: - pass - def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: - pass - def get_replication_streams(self) -> Dict[str, Stream]: - pass - def get_http_client( - self, - ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: - pass - def should_send_federation(self) -> bool: - pass - def get_typing_handler(self) -> FollowerTypingHandler: - pass diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 25ccef5aa5d5..a1d388466734 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -28,7 +28,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ec89f645d401..5ef38535593b 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,18 +17,19 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `Database` class represents a single physical database. The -`data_stores` are classes that talk directly to a `Database` instance and have -associated schemas, background updates, etc. On top of those there are classes -that provide high level interfaces that combine calls to multiple `data_stores`. +databases). The `DatabasePool` class represents connections to a single physical +database. The `databases` are classes that talk directly to a `DatabasePool` +instance and have associated schemas, background updates, etc. On top of those +there are classes that provide high level interfaces that combine calls to +multiple `databases`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main import DataStore +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage @@ -40,7 +41,7 @@ class Storage(object): """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 985a04286961..6814bf5fcf1e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -23,7 +23,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id logger = logging.getLogger(__name__) @@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine - self.db = database + self.db_pool = database self.rand = random.SystemRandom() def process_replication_rows(self, stream_name, instance_name, token, rows): @@ -58,7 +58,6 @@ def _invalidate_state_caches(self, room_id, members_changed): """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 018826ef6947..f43463df53b8 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -88,7 +88,7 @@ class BackgroundUpdater(object): def __init__(self, hs, database): self._clock = hs.get_clock() - self.db = database + self.db_pool = database # if a background update is currently running, its name. self._current_background_update = None # type: Optional[str] @@ -139,7 +139,7 @@ async def has_completed_background_updates(self) -> bool: # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = await self.db.simple_select_onecol( + updates = await self.db_pool.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -160,7 +160,7 @@ async def has_completed_background_update(self, update_name) -> bool: if update_name == self._current_background_update: return False - update_exists = await self.db.simple_select_one_onecol( + update_exists = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -189,10 +189,10 @@ def get_background_updates_txn(txn): ORDER BY ordering, update_name """ ) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) if not self._current_background_update: - all_pending_updates = await self.db.runInteraction( + all_pending_updates = await self.db_pool.runInteraction( "background_updates", get_background_updates_txn, ) if not all_pending_updates: @@ -243,7 +243,7 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = await self.db.simple_select_one_onecol( + progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -402,7 +402,7 @@ def create_index_sqlite(conn): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.db.engine, engines.PostgresEngine): + if isinstance(self.db_pool.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None @@ -413,7 +413,7 @@ def create_index_sqlite(conn): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.db.runWithConnection(runner) + yield self.db_pool.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -433,7 +433,7 @@ def _end_background_update(self, update_name): % update_name ) self._current_background_update = None - return self.db.simple_delete_one( + return self.db_pool.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -445,7 +445,7 @@ def _background_update_progress(self, update_name: str, progress: dict): progress: The progress of the update. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, @@ -463,7 +463,7 @@ def _background_update_progress_txn(self, txn, update_name, progress): progress_json = json.dumps(progress) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c7387..000000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. -ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ce8757a40000..4ada6f556327 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -279,7 +279,7 @@ def interval(self, interval_duration_secs, limit=3): return top_n_counters -class Database(object): +class DatabasePool(object): """Wraps a single physical database and connection pool. A single database may be used by multiple data stores. diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/databases/__init__.py similarity index 69% rename from synapse/storage/data_stores/__init__.py rename to synapse/storage/databases/__init__.py index 599ee470d423..4406e5827341 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -15,17 +15,17 @@ import logging -from synapse.storage.data_stores.main.events import PersistEventsStore -from synapse.storage.data_stores.state import StateGroupDataStore -from synapse.storage.database import Database, make_conn +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.state import StateGroupDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database logger = logging.getLogger(__name__) -class DataStores(object): - """The various data stores. +class Databases(object): + """The various databases. These are low level interfaces to physical databases. @@ -38,9 +38,9 @@ def __init__(self, main_store_class, hs): # store. self.databases = [] - self.main = None - self.state = None - self.persist_events = None + main = None + state = None + persist_events = None for database_config in hs.config.database.databases: db_name = database_config.name @@ -51,37 +51,35 @@ def __init__(self, main_store_class, hs): engine.check_database(db_conn) prepare_database( - db_conn, engine, hs.config, data_stores=database_config.data_stores, + db_conn, engine, hs.config, databases=database_config.databases, ) - database = Database(hs, database_config, engine) + database = DatabasePool(hs, database_config, engine) - if "main" in database_config.data_stores: + if "main" in database_config.databases: logger.info("Starting 'main' data store") # Sanity check we don't try and configure the main store on # multiple databases. - if self.main: + if main: raise Exception("'main' data store already configured") - self.main = main_store_class(database, db_conn, hs) + main = main_store_class(database, db_conn, hs) # If we're on a process that can persist events also # instantiate a `PersistEventsStore` if hs.config.worker.writers.events == hs.get_instance_name(): - self.persist_events = PersistEventsStore( - hs, database, self.main - ) + persist_events = PersistEventsStore(hs, database, main) - if "state" in database_config.data_stores: + if "state" in database_config.databases: logger.info("Starting 'state' data store") # Sanity check we don't try and configure the state store on # multiple databases. - if self.state: + if state: raise Exception("'state' data store already configured") - self.state = StateGroupDataStore(database, db_conn, hs) + state = StateGroupDataStore(database, db_conn, hs) db_conn.commit() @@ -90,8 +88,14 @@ def __init__(self, main_store_class, hs): logger.info("Database %r prepared", db_name) # Sanity check that we have actually configured all the required stores. - if not self.main: + if not main: raise Exception("No 'main' data store configured") - if not self.state: + if not state: raise Exception("No 'main' data store configured") + + # We use local variables here to ensure that the databases do not have + # optional types. + self.main = main + self.state = state + self.persist_events = persist_events diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/databases/main/__init__.py similarity index 95% rename from synapse/storage/data_stores/main/__init__.py rename to synapse/storage/databases/main/__init__.py index 932458f651eb..17fa47091950 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -21,7 +21,7 @@ from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( IdGenerator, @@ -119,7 +119,7 @@ class DataStore( CacheInvalidationWorkerStore, ServerMetricsStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine @@ -174,7 +174,7 @@ def __init__(self, database: Database, db_conn, hs): self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self.db.get_cache_dict( + presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -188,7 +188,7 @@ def __init__(self, database: Database, db_conn, hs): ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -203,7 +203,7 @@ def __init__(self, database: Database, db_conn, hs): ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -229,7 +229,7 @@ def __init__(self, database: Database, db_conn, hs): ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -243,7 +243,7 @@ def __init__(self, database: Database, db_conn, hs): prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -282,7 +282,7 @@ def _get_active_presence(self, db_conn): txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) txn.close() for row in rows: @@ -295,7 +295,9 @@ def count_daily_users(self): Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db.runInteraction("count_daily_users", self._count_users, yesterday) + return self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) def count_monthly_users(self): """ @@ -305,7 +307,7 @@ def count_monthly_users(self): amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db.runInteraction( + return self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -405,7 +407,7 @@ def _count_r30_users(txn): return results - return self.db.runInteraction("count_r30_users", _count_r30_users) + return self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -470,7 +472,7 @@ def _generate_user_daily_visits(txn): # frequently self._last_user_visit_update = now - return self.db.runInteraction( + return self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -481,7 +483,7 @@ def get_users(self): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.db.simple_select_list( + return self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ @@ -543,10 +545,12 @@ def get_users_paginate_txn(txn): where_clause ) txn.execute(sql, args) - users = self.db.cursor_to_dict(txn) + users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) + return self.db_pool.runInteraction( + "get_users_paginate_txn", get_users_paginate_txn + ) def search_users(self, term): """Function to search users list for one or more users with @@ -558,7 +562,7 @@ def search_users(self, term): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.db.simple_search_list( + return self.db_pool.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/databases/main/account_data.py similarity index 82% rename from synapse/storage/data_stores/main/account_data.py rename to synapse/storage/databases/main/account_data.py index 33cc372dfd7e..82aac2bbf3a3 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,16 +16,16 @@ import abc import logging -from typing import List, Tuple - -from canonicaljson import json +from typing import List, Optional, Tuple from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.types import JsonDict +from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -69,7 +69,7 @@ def get_account_data_for_user(self, user_id): """ def get_account_data_for_user_txn(txn): - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -80,7 +80,7 @@ def get_account_data_for_user_txn(txn): row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -94,17 +94,19 @@ def get_account_data_for_user_txn(txn): return global_account_data, by_room - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): + @cached(num_args=2, max_entries=5000) + async def get_global_account_data_by_type_for_user( + self, data_type: str, user_id: str + ) -> Optional[JsonDict]: """ Returns: - Deferred: A dict + The account data. """ - result = yield self.db.simple_select_one_onecol( + result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -129,7 +131,7 @@ def get_account_data_for_room(self, user_id, room_id): """ def get_account_data_for_room_txn(txn): - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -140,7 +142,7 @@ def get_account_data_for_room_txn(txn): row["account_data_type"]: db_to_json(row["content"]) for row in rows } - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -158,7 +160,7 @@ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type """ def get_account_data_for_room_and_type_txn(txn): - content_json = self.db.simple_select_one_onecol_txn( + content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -172,7 +174,7 @@ def get_account_data_for_room_and_type_txn(txn): return db_to_json(content_json) if content_json else None - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -202,7 +204,7 @@ def get_updated_global_account_data_txn(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn ) @@ -232,7 +234,7 @@ def get_updated_room_account_data_txn(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn ) @@ -277,13 +279,15 @@ def get_updated_account_data_for_user_txn(txn): if not changed: return defer.succeed(({}, {})) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( + @cached(num_args=2, cache_context=True, max_entries=5000) + async def is_ignored_by( + self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext + ) -> bool: + ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, @@ -295,7 +299,7 @@ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", @@ -308,32 +312,35 @@ def __init__(self, database: Database, db_conn, hs): super(AccountDataStore, self).__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: - A deferred int. + The maximum stream ID. """ return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -351,7 +358,7 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -360,26 +367,28 @@ def add_account_data_to_room(self, user_id, room_id, account_data_type, content) (user_id, room_id, account_data_type), content ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -397,7 +406,7 @@ def add_account_data_for_user(self, user_id, account_data_type, content): # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -405,14 +414,13 @@ def add_account_data_for_user(self, user_id, account_data_type, content): (account_data_type, user_id) ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id): + def _update_max_stream_id(self, next_id: int): """Update the max stream_id Args: - next_id(int): The the revision to advance to. + next_id: The the revision to advance to. """ # Note: This is only here for backwards compat to allow admins to @@ -427,4 +435,4 @@ def _update(txn): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.db.runInteraction("update_account_data_max_stream_id", _update) + return self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/databases/main/appservice.py similarity index 89% rename from synapse/storage/data_stores/main/appservice.py rename to synapse/storage/databases/main/appservice.py index 56659fed37d9..5cf1a8839950 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -18,13 +18,11 @@ from canonicaljson import json -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore logger = logging.getLogger(__name__) @@ -49,7 +47,7 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +143,15 @@ def get_appservices_by_state(self, state): services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -176,7 +171,7 @@ def set_appservice_state(self, service, state): Returns: A Deferred which resolves when the state was set successfully. """ - return self.db.simple_upsert( + return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) @@ -217,7 +212,9 @@ def _create_appservice_txn(txn): ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db_pool.runInteraction( + "create_appservice_txn", _create_appservice_txn + ) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -250,7 +247,7 @@ def _complete_appservice_txn(txn): ) # Set current txn_id for AS to 'txn_id' - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "application_services_state", {"as_id": service.id}, @@ -258,26 +255,24 @@ def _complete_appservice_txn(txn): ) # Delete txn - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "application_services_txns", {"txn_id": txn_id, "as_id": service.id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. """ def _get_oldest_unsent_txn(txn): @@ -288,7 +283,7 @@ def _get_oldest_unsent_txn(txn): " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return None @@ -296,7 +291,7 @@ def _get_oldest_unsent_txn(txn): return entry - entry = yield self.db.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -305,7 +300,7 @@ def _get_oldest_unsent_txn(txn): event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -326,12 +321,11 @@ def set_appservice_last_pos_txn(txn): "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -355,11 +349,11 @@ def get_new_events_for_appservice_txn(txn): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/databases/main/cache.py similarity index 97% rename from synapse/storage/data_stores/main/cache.py rename to synapse/storage/databases/main/cache.py index edc3624fed6a..10de4460651c 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -26,7 +26,7 @@ EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -39,7 +39,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -92,7 +92,7 @@ def get_all_updated_caches_txn(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn ) @@ -172,7 +172,6 @@ def _invalidate_caches_for_event( self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: @@ -203,7 +202,7 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, .. return cache_func.invalidate(keys) - await self.db.runInteraction( + await self.db_pool.runInteraction( "invalidate_cache_and_stream", self._send_invalidation_to_replication, cache_func.__name__, @@ -288,7 +287,7 @@ def _send_invalidation_to_replication( if keys is not None: keys = list(keys) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="cache_invalidation_stream_by_instance", values={ diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/databases/main/censor_events.py similarity index 88% rename from synapse/storage/data_stores/main/censor_events.py rename to synapse/storage/databases/main/censor_events.py index 2d4826172457..f211ddbaf88e 100644 --- a/synapse/storage/data_stores/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -16,15 +16,13 @@ import logging from typing import TYPE_CHECKING -from twisted.internet import defer - from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.data_stores.main.events import encode_json -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.events import encode_json +from synapse.storage.databases.main.events_worker import EventsWorkerStore if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,7 +32,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: Database, db_conn, hs: "HomeServer"): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) def _censor_redactions(): @@ -56,7 +54,7 @@ async def _censor_redactions(self): return if not ( - await self.db.updates.has_completed_background_update( + await self.db_pool.updates.has_completed_background_update( "redactions_have_censored_ts_idx" ) ): @@ -85,7 +83,7 @@ async def _censor_redactions(self): LIMIT ? """ - rows = await self.db.execute( + rows = await self.db_pool.execute( "_censor_redactions_fetch", None, sql, before_ts, 100 ) @@ -123,14 +121,14 @@ def _update_censor_txn(txn): if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) def _censor_event_txn(self, txn, event_id, pruned_json): """Censor an event by replacing its JSON in the event_json table with the @@ -141,24 +139,23 @@ def _censor_event_txn(self, txn, event_id, pruned_json): event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, updatevalues={"json": pruned_json}, ) - @defer.inlineCallbacks - def expire_event(self, event_id): + async def expire_event(self, event_id: str) -> None: """Retrieve and expire an event that has expired, and delete its associated expiry timestamp. If the event can't be retrieved, delete its associated timestamp so we don't try to expire it again in the future. Args: - event_id (str): The ID of the event to delete. + event_id: The ID of the event to delete. """ # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) + event = await self.get_event(event_id) def delete_expired_event_txn(txn): # Delete the expiry timestamp associated with this event from the database. @@ -193,7 +190,9 @@ def delete_expired_event_txn(txn): txn, "_get_event_cache", (event.event_id,) ) - yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + await self.db_pool.runInteraction( + "delete_expired_event", delete_expired_event_txn + ) def _delete_event_expiry_txn(self, txn, event_id): """Delete the expiry timestamp associated with an event ID without deleting the @@ -203,6 +202,6 @@ def _delete_event_expiry_txn(self, txn, event_id): txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self.db.simple_delete_txn( + return self.db_pool.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/databases/main/client_ips.py similarity index 86% rename from synapse/storage/data_stores/main/client_ips.py rename to synapse/storage/databases/main/client_ips.py index 995d4764a9e0..216a5925fc37 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -14,12 +14,11 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_tuple_comparison_clause +from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -31,40 +30,40 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -73,28 +72,28 @@ def __init__(self, database: Database, db_conn, hs): ) # Drop the old non-unique index - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): + async def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( + "user_ips_drop_nonunique_index" + ) return 1 - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress, batch_size): # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. @@ -104,14 +103,13 @@ def _analyze_user_ip(self, progress, batch_size): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) + await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) - yield self.db.updates._end_background_update("user_ips_analyze") + await self.db_pool.updates._end_background_update("user_ips_analyze") return 1 - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress, batch_size): # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they @@ -138,7 +136,7 @@ def get_last_seen(txn): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.db.runInteraction( + end_last_seen = await self.db_pool.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -269,19 +267,18 @@ def remove(txn): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.db.runInteraction("user_ips_dups_remove", remove) + await self.db_pool.runInteraction("user_ips_dups_remove", remove) if last: - yield self.db.updates._end_background_update("user_ips_remove_dupes") + await self.db_pool.updates._end_background_update("user_ips_remove_dupes") return batch_size - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update(self, progress, batch_size): """Background update to insert last seen info into devices table """ @@ -336,7 +333,7 @@ def _devices_last_seen_update_txn(txn): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -344,18 +341,18 @@ def _devices_last_seen_update_txn(txn): return len(rows) - updated = yield self.db.runInteraction( + updated = await self.db_pool.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self.db.updates._end_background_update("devices_last_seen") + await self.db_pool.updates._end_background_update("devices_last_seen") return updated class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 @@ -378,8 +375,7 @@ def __init__(self, database: Database, db_conn, hs): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -390,7 +386,7 @@ def insert_client_ip( last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -403,18 +399,18 @@ def insert_client_ip( def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.db.is_running(): + if not self.db_pool.is_running(): return to_update = self._batch_row_update self._batch_row_update = {} - return self.db.runInteraction( + return self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self.db._unsafe_to_upsert_tables or ( + if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -423,7 +419,7 @@ def _update_client_ips_batch_txn(self, txn, to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -445,7 +441,7 @@ def _update_client_ips_batch_txn(self, txn, to_update): # this is always an update rather than an upsert: the row should # already exist, and if it doesn't, that may be because it has been # deleted, and we don't want to re-create it. - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -459,25 +455,25 @@ def _update_client_ips_batch_txn(self, txn, to_update): # Failed to upsert, log and continue logger.error("Failed to insert client IP %r: %r", entry, e) - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: """For each device_id listed, give the user_ip it was last seen on Args: - user_id (str) - device_id (str): If None fetches all devices for the user + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. """ keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.db.simple_select_list( + res = await self.db_pool.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -499,8 +495,7 @@ def get_last_client_ip_by_device(self, user_id, device_id): } return ret - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): + async def get_user_ip_and_agents(self, user): user_id = user.to_string() results = {} @@ -510,7 +505,7 @@ def get_user_ip_and_agents(self, user): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -540,7 +535,7 @@ async def _prune_old_user_ips(self): # Nothing to do return - if not await self.db.updates.has_completed_background_update( + if not await self.db_pool.updates.has_completed_background_update( "devices_last_seen" ): # Only start pruning if we have finished populating the devices @@ -573,4 +568,6 @@ async def _prune_old_user_ips(self): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py similarity index 84% rename from synapse/storage/data_stores/main/deviceinbox.py rename to synapse/storage/databases/main/deviceinbox.py index da297b31fbbe..1f6e995c4fef 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,13 +16,10 @@ import logging from typing import List, Tuple -from canonicaljson import json - -from twisted.internet import defer - from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -32,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): + async def get_new_messages_for_device( + self, + user_id: str, + device_id: str, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device + user_id: The recipient user_id. + device_id: The recipient device_id. + last_stream_id: The last stream ID checked. + current_stream_id: The current position of the to device message stream. + limit: The maximum number of messages to retrieve. + Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id ) if not has_changed: - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ( @@ -70,20 +74,22 @@ def get_new_messages_for_device_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + async def delete_messages_for_device( + self, user_id: str, device_id: str, up_to_stream_id: int + ) -> int: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. + user_id: The recipient user_id. + device_id: The recipient device_id. + up_to_stream_id: Where to delete messages up to. + Returns: - A deferred that resolves to the number of messages deleted. + The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting @@ -110,7 +116,7 @@ def delete_messages_for_device_txn(txn): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.db.runInteraction( + count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -129,9 +135,9 @@ def delete_messages_for_device_txn(txn): return count @trace - def get_new_device_msgs_for_remote( + async def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit - ): + ) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. @@ -140,8 +146,7 @@ def get_new_device_msgs_for_remote( current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) @@ -154,11 +159,11 @@ def get_new_device_msgs_for_remote( ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) + return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): @@ -179,7 +184,7 @@ def get_new_messages_for_remote_destination_txn(txn): stream_pos = current_stream_id return messages, stream_pos - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -204,7 +209,7 @@ def delete_messages_for_remote_destination_txn(txn): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) @@ -269,7 +274,7 @@ def get_all_new_device_messages_txn(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) @@ -277,30 +282,29 @@ def get_all_new_device_messages_txn(txn): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.db.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -308,7 +312,7 @@ def reindex_txn(conn): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been @@ -321,21 +325,21 @@ def __init__(self, database: Database, db_conn, hs): ) @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): + async def add_messages_to_device_inbox( + self, + local_messages_by_user_then_device: dict, + remote_messages_by_destination: dict, + ) -> int: """Used to send messages from this server. Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): + local_messages_by_user_and_device: Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): + remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. + Returns: - A deferred stream_id that resolves when the messages have been - inserted. + The new stream_id. """ def add_messages_txn(txn, now_ms, stream_id): @@ -354,13 +358,13 @@ def add_messages_txn(txn, now_ms, stream_id): ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) + edu_json = json_encoder.encode(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -372,15 +376,14 @@ def add_messages_txn(txn, now_ms, stream_id): return self._device_inbox_id_gen.get_current_token() - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): + async def add_messages_from_remote_to_device_inbox( + self, origin: str, message_id: str, local_messages_by_user_then_device: dict + ) -> int: def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self.db.simple_select_one_txn( + already_inserted = self.db_pool.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -392,7 +395,7 @@ def add_messages_txn(txn, now_ms, stream_id): # Add an entry for this message_id so that we know we've processed # it. - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -410,7 +413,7 @@ def add_messages_txn(txn, now_ms, stream_id): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -432,7 +435,7 @@ def _add_messages_to_local_device_inbox_txn( # Handle wildcard device_ids. sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) + message_json = json_encoder.encode(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -454,7 +457,7 @@ def _add_messages_to_local_device_inbox_txn( # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = json.dumps(messages_by_device[device]) + message_json = json_encoder.encode(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/databases/main/devices.py similarity index 77% rename from synapse/storage/data_stores/main/devices.py rename to synapse/storage/databases/main/devices.py index 45581a65004e..2b330604803d 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,11 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Optional, Set, Tuple - -from canonicaljson import json - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -31,17 +27,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( - Database, + DatabasePool, LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key +from synapse.util import json_encoder +from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -55,38 +47,36 @@ class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): + def get_device(self, user_id: str, device_id: str): """Retrieve a device. Only returns devices that are not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve Returns: defer.Deferred for a dict containing the device information Raises: StoreError: if the device is not found """ - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", ) - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. Args: - user_id (str): + user_id: Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. + A mapping from device_id to a dict containing "device_id", "user_id" + and "display_name" for each device. """ - devices = yield self.db.simple_select_list( + devices = await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -96,19 +86,20 @@ def get_devices_by_user(self, user_id): return {d["device_id"]: d for d in devices} @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): + async def get_device_updates_by_remote( + self, destination: str, from_stream_id: int, limit: int + ) -> Tuple[int, List[Tuple[str, dict]]]: """Get a stream of device updates to send to the given remote server. Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + limit: Maximum number of device updates to return + Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents + A mapping from the current stream id (ie, the stream id of the last + update included in the response), and the list of updates, where + each update is a pair of EDU type and EDU contents. """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -118,7 +109,7 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): if not has_changed: return now_stream_id, [] - updates = yield self.db.runInteraction( + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -137,7 +128,7 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -150,7 +141,7 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): "device_id": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key( + cross_signing_key = await self.get_e2e_cross_signing_key( user, "self_signing" ) if cross_signing_key: @@ -201,7 +192,7 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - results = yield self._get_device_update_edus_by_remote( + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -214,16 +205,21 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): return now_stream_id, results def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit + self, + txn: LoggingTransaction, + destination: str, + from_stream_id: int, + now_stream_id: int, + limit: int, ): """Return device update information for a given remote destination Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return + txn: The transaction to execute + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + now_stream_id: The maximum stream_id to filter updates by, inclusive + limit: Maximum number of device updates to return Returns: List: List of device updates @@ -239,23 +235,26 @@ def _get_device_updates_by_remote_txn( return list(txn) - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + async def _get_device_update_edus_by_remote( + self, + destination: str, + from_stream_id: int, + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], + ) -> List[Tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded + user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: - List[Dict]: List of objects representing an device update EDU - + List of objects representing an device update EDU """ devices = ( - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -270,7 +269,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( + prev_id = await self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) @@ -314,7 +313,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m return results def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id + self, destination: str, user_id: str, from_stream_id: int ): def f(txn): prev_sent_id_sql = """ @@ -326,19 +325,21 @@ def f(txn): rows = txn.fetchall() return rows[0][0] - return self.db.runInteraction("get_last_device_update_for_remote_user", f) + return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) - def mark_as_sent_devices_by_remote(self, destination, stream_id): + def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): """Mark that updates have successfully been sent to the destination. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + def _mark_as_sent_devices_by_remote_txn( + self, txn: LoggingTransaction, destination: str, stream_id: int + ) -> None: # We update the device_lists_outbound_last_success with the successfully # poked users. sql = """ @@ -350,7 +351,7 @@ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn=txn, table="device_lists_outbound_last_success", key_names=("destination", "user_id"), @@ -366,17 +367,21 @@ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): """ txn.execute(sql, (destination, stream_id)) - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): + async def add_user_signature_change_to_streams( + self, from_user_id: str, user_ids: List[str] + ) -> int: """Persist that a user has made new signatures Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed + from_user_id: the user who made the signatures + user_ids: the users who were signed + + Returns: + THe new stream ID. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -385,45 +390,52 @@ def add_user_signature_change_to_streams(self, from_user_id, user_ids): ) return stream_id - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + def _add_user_signature_change_txn( + self, + txn: LoggingTransaction, + from_user_id: str, + user_ids: List[str], + stream_id: int, + ) -> None: txn.call_after( self._user_signature_stream_cache.entity_has_changed, from_user_id, stream_id, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "user_signature_stream", values={ "stream_id": stream_id, "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), + "user_ids": json_encoder.encode(user_ids), }, ) - def get_device_stream_token(self): + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): + async def get_user_devices_from_cache( + self, query_list: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list(list): List of (user_id, device_ids), if device_ids is + query_list: List of (user_id, device_ids), if device_ids is falsey then return all device ids for that user. Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + A tuple of (user_ids_not_in_cache, results_map), where + user_ids_not_in_cache is a set of user_ids and results_map is a + mapping of user_id -> device_id -> device_info. """ user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. - users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( user_ids ) user_ids_in_cache = { @@ -437,19 +449,19 @@ def get_user_devices_from_cache(self, query_list): continue if device_id: - device = yield self._get_cached_user_device(user_id, device_id) + device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) + results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) return user_ids_not_in_cache, results - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db.simple_select_one_onecol( + @cached(num_args=2, tree=True) + async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -457,9 +469,9 @@ def _get_cached_user_device(self, user_id, device_id): ) return db_to_json(content) - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db.simple_select_list( + @cached() + async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -469,19 +481,21 @@ def get_cached_devices_for_user(self, user_id): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id): + def get_devices_with_keys_by_user(self, user_id: str): """Get all devices (with any device keys) for a user Returns: - (stream_id, devices) + Deferred which resolves to (stream_id, devices) """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, ) - def _get_devices_with_keys_by_user_txn(self, txn, user_id): + def _get_devices_with_keys_by_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() devices = self._get_e2e_device_keys_txn( @@ -514,17 +528,18 @@ def _get_devices_with_keys_by_user_txn(self, txn, user_id): return now_stream_id, [] - def get_users_whose_devices_changed(self, from_key, user_ids): + async def get_users_whose_devices_changed( + self, from_key: str, user_ids: Iterable[str] + ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) + from_key: The device lists stream token + user_ids: The user IDs to query for devices. Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` + The set of user_ids whose devices have changed since `from_key` """ from_key = int(from_key) @@ -535,7 +550,7 @@ def get_users_whose_devices_changed(self, from_key, user_ids): ) if not to_check: - return defer.succeed(set()) + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() @@ -555,18 +570,22 @@ def _get_users_whose_devices_changed_txn(txn): return changes - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): + async def get_users_whose_signatures_changed( + self, user_id: str, from_key: str + ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token + user_id: the user who made the signatures + from_key: The device lists stream token + + Returns: + A set of user IDs with updated signatures. """ from_key = int(from_key) if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): @@ -574,7 +593,7 @@ def get_users_whose_signatures_changed(self, user_id, from_key): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.db.execute( + rows = await self.db_pool.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} @@ -600,7 +619,7 @@ async def get_all_device_list_changes_for_remotes( between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ @@ -631,17 +650,17 @@ def _get_all_device_list_changes_for_remotes(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_device_list_changes_for_remotes", _get_all_device_list_changes_for_remotes, ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): + def get_device_list_last_stream_id_for_remote(self, user_id: str): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -654,8 +673,8 @@ def get_device_list_last_stream_id_for_remote(self, user_id): list_name="user_ids", inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self.db.simple_select_many_batch( + def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + rows = yield self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -668,8 +687,7 @@ def get_device_list_last_stream_id_for_remotes(self, user_ids): return results - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( + async def get_user_ids_requiring_device_list_resync( self, user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we @@ -680,7 +698,7 @@ def get_user_ids_requiring_device_list_resync( The IDs of users whose device lists need resync. """ if user_ids: - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", column="user_id", iterable=user_ids, @@ -688,7 +706,7 @@ def get_user_ids_requiring_device_list_resync( desc="get_user_ids_requiring_device_list_resync_with_iterable", ) else: - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, retcols=("user_id",), @@ -701,7 +719,7 @@ def mark_remote_user_device_cache_as_stale(self, user_id: str): """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, @@ -709,12 +727,12 @@ def mark_remote_user_device_cache_as_stale(self, user_id: str): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id): + def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): """Mark that we no longer track device lists for remote user. """ def _mark_remote_user_device_list_as_unsubscribed_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -723,17 +741,17 @@ def _mark_remote_user_device_list_as_unsubscribed_txn(txn): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -741,7 +759,7 @@ def __init__(self, database: Database, db_conn, hs): ) # create a unique index on device_lists_remote_cache - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -750,7 +768,7 @@ def __init__(self, database: Database, db_conn, hs): ) # And one on device_lists_remote_extremeties - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -759,35 +777,34 @@ def __init__(self, database: Database, db_conn, hs): ) # once they complete, we can remove the old non-unique indexes. - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) # clear out duplicate device list outbound pokes - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) # a pair of background updates that were added during the 1.14 release cycle, # but replaced with 58/06dlols_unique_idx.py - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( "device_lists_outbound_last_success_unique_idx", ) - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( "drop_device_lists_outbound_last_success_non_unique_idx", ) - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES ) return 1 @@ -807,7 +824,7 @@ async def _remove_duplicate_outbound_pokes(self, progress, batch_size): def _txn(txn): clause, args = make_tuple_comparison_clause( - self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS] ) sql = """ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts @@ -823,30 +840,32 @@ def _txn(txn): ",".join(KEY_COLS), # ORDER BY ) txn.execute(sql, args + [batch_size]) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) row = None for row in rows: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, ) row["sent"] = False - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "device_lists_outbound_pokes", row, ) if row: - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, ) return len(rows) - rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + rows = await self.db_pool.runInteraction( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn + ) if not rows: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES ) @@ -854,7 +873,7 @@ def _txn(txn): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies @@ -865,18 +884,20 @@ def __init__(self, database: Database, db_conn, hs): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device( + self, user_id: str, device_id: str, initial_device_display_name: str + ) -> bool: """Ensure the given device is known; add it to the store if not Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. + Whether the device was inserted or an existing device existed with that ID. + Raises: StoreError: if the device is already in use """ @@ -885,7 +906,7 @@ def store_device(self, user_id, device_id, initial_device_display_name): return False try: - inserted = yield self.db.simple_insert( + inserted = await self.db_pool.simple_insert( "devices", values={ "user_id": user_id, @@ -899,7 +920,7 @@ def store_device(self, user_id, device_id, initial_device_display_name): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db.simple_select_one_onecol( + hidden = await self.db_pool.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -924,17 +945,14 @@ def store_device(self, user_id, device_id, initial_device_display_name): ) raise StoreError(500, "Problem storing device.") - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """Delete a device. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the device + device_id: The ID of the device to delete """ - yield self.db.simple_delete_one( + await self.db_pool.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -942,17 +960,14 @@ def delete_device(self, user_id, device_id): self.device_id_exists_cache.invalidate((user_id, device_id)) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete """ - yield self.db.simple_delete_many( + await self.db_pool.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -962,26 +977,25 @@ def delete_devices(self, user_id, device_ids): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - def update_device(self, user_id, device_id, new_display_name=None): + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: """Update a device. Only updates the device if it is not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged Raises: StoreError: if the device is not found - Returns: - defer.Deferred """ updates = {} if new_display_name is not None: updates["display_name"] = new_display_name if not updates: - return defer.succeed(None) - return self.db.simple_update_one( + return None + await self.db_pool.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -989,7 +1003,7 @@ def update_device(self, user_id, device_id, new_display_name=None): ) def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id + self, user_id: str, device_id: str, content: JsonDict, stream_id: int ): """Updates a single device in the cache of a remote user's devicelist. @@ -997,15 +1011,15 @@ def update_remote_device_list_cache_entry( device list. Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list + user_id: User to update device list for + device_id: ID of decivice being updated + content: new data on this device + stream_id: the version of the device list Returns: Deferred[None] """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -1015,10 +1029,15 @@ def update_remote_device_list_cache_entry( ) def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: int, + ) -> None: if content.get("deleted"): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -1026,11 +1045,11 @@ def _update_remote_device_list_cache_entry_txn( txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, + values={"content": json_encoder.encode(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, @@ -1042,7 +1061,7 @@ def _update_remote_device_list_cache_entry_txn( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -1052,21 +1071,23 @@ def _update_remote_device_list_cache_entry_txn( lock=False, ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): + def update_remote_device_list_cache( + self, user_id: str, devices: List[dict], stream_id: int + ): """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's device list. Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list + user_id: User to update device list for + devices: list of device objects supplied over federation + stream_id: the version of the device list Returns: Deferred[None] """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1074,19 +1095,21 @@ def update_remote_device_list_cache(self, user_id, devices, stream_id): stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self.db.simple_delete_txn( + def _update_remote_device_list_cache_txn( + self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int + ): + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ { "user_id": user_id, "device_id": content["device_id"], - "content": json.dumps(content), + "content": json_encoder.encode(content), } for content in devices ], @@ -1098,7 +1121,7 @@ def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -1111,12 +1134,13 @@ def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id) # If we're replacing the remote user's device list cache presumably # we've done a full resync, so we remove the entry that says we need # to resync - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, ) - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): + async def add_device_change_to_streams( + self, user_id: str, device_ids: Collection[str], hosts: List[str] + ): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ @@ -1124,7 +1148,7 @@ def add_device_change_to_streams(self, user_id, device_ids, hosts): return with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, user_id, @@ -1139,7 +1163,7 @@ def add_device_change_to_streams(self, user_id, device_ids, hosts): with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_device_outbound_poke_to_stream", self._add_device_outbound_poke_to_stream_txn, user_id, @@ -1174,7 +1198,7 @@ def _add_device_change_to_stream_txn( [(user_id, device_id, min_stream_id) for device_id in device_ids], ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1184,7 +1208,13 @@ def _add_device_change_to_stream_txn( ) def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + hosts: List[str], + stream_ids: List[str], + context: Dict[str, str], ): for host in hosts: txn.call_after( @@ -1196,7 +1226,7 @@ def _add_device_outbound_poke_to_stream_txn( now = self._clock.time_msec() next_stream_id = iter(stream_ids) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1207,7 +1237,7 @@ def _add_device_outbound_poke_to_stream_txn( "device_id": device_id, "sent": False, "ts": now, - "opentracing_context": json.dumps(context) + "opentracing_context": json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", } @@ -1216,7 +1246,7 @@ def _add_device_outbound_poke_to_stream_txn( ], ) - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. @@ -1303,7 +1333,7 @@ def _prune_txn(txn): return run_as_background_process( "prune_old_outbound_device_pokes", - self.db.runInteraction, + self.db_pool.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/databases/main/directory.py similarity index 77% rename from synapse/storage/data_stores/main/directory.py rename to synapse/storage/databases/main/directory.py index e1d1bc3e0586..037e02603c7b 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,30 +14,29 @@ # limitations under the License. from collections import namedtuple -from typing import Optional - -from twisted.internet import defer +from typing import Iterable, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: + """Gets the room_id and server list for a given room_alias Args: - room_alias (RoomAlias) + room_alias: The alias to translate to an ID. Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found + The room alias mapping or None if no association can be found. """ - room_id = yield self.db.simple_select_one_onecol( + room_id = await self.db_pool.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +47,7 @@ def get_association_from_room_alias(self, room_alias): if not room_id: return None - servers = yield self.db.simple_select_onecol( + servers = await self.db_pool.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -61,7 +60,7 @@ def get_association_from_room_alias(self, room_alias): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -70,7 +69,7 @@ def get_room_alias_creator(self, room_alias): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -79,22 +78,24 @@ def get_aliases_for_room(self, room_id): class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + async def create_room_alias_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Iterable[str], + creator: Optional[str] = None, + ) -> None: """ Creates an association between a room alias and room_id/servers Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred + room_alias: The alias to create. + room_id: The target of the alias. + servers: A list of servers through which it may be possible to join the room + creator: Optional user_id of creator. """ def alias_txn(txn): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "room_aliases", { @@ -104,7 +105,7 @@ def alias_txn(txn): }, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -118,24 +119,22 @@ def alias_txn(txn): ) try: - ret = yield self.db.runInteraction( + await self.db_pool.runInteraction( "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - return ret - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db.runInteraction( + async def delete_room_alias(self, room_alias: RoomAlias) -> str: + room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias): + def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), @@ -190,6 +189,6 @@ def _update_aliases_for_room_txn(txn): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py similarity index 90% rename from synapse/storage/data_stores/main/e2e_room_keys.py rename to synapse/storage/databases/main/e2e_room_keys.py index 615364f01837..2eeb9f97dc14 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,18 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + async def update_e2e_room_key( + self, user_id, version, room_id, session_id, room_key + ): """Replaces the encrypted E2E room key for a given session in a given backup Args: @@ -38,7 +36,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): StoreError """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -50,13 +48,12 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), }, desc="update_e2e_room_key", ) - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys(self, user_id, version, room_keys): """Bulk add room keys to a given backup. Args: @@ -77,7 +74,7 @@ def add_e2e_room_keys(self, user_id, version, room_keys): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), } ) log_kv( @@ -89,13 +86,12 @@ def add_e2e_room_keys(self, user_id, version, room_keys): } ) - yield self.db.simple_insert_many( + await self.db_pool.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -110,7 +106,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): the backup (or for the specified room) Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -125,7 +121,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): if session_id: keyvalues["session_id"] = session_id - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -171,7 +167,7 @@ def get_e2e_room_keys_multi(self, user_id, version, room_keys): Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -235,7 +231,7 @@ def count_e2e_room_keys(self, user_id, version): version (str): the version ID of the backup we're querying about """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -243,8 +239,9 @@ def count_e2e_room_keys(self, user_id, version): ) @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -259,7 +256,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): the backup (or for the specified room) Returns: - A deferred of the deletion transaction + The deletion transaction """ keyvalues = {"user_id": user_id, "version": int(version)} @@ -268,7 +265,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): if session_id: keyvalues["session_id"] = session_id - yield self.db.simple_delete( + await self.db_pool.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -313,7 +310,7 @@ def _get_e2e_room_keys_version_info_txn(txn): # it isn't there. raise StoreError(404, "No row found") - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -325,7 +322,7 @@ def _get_e2e_room_keys_version_info_txn(txn): result["etag"] = 0 return result - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -353,20 +350,20 @@ def _create_e2e_room_keys_version_txn(txn): new_version = str(int(current_version) + 1) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), + "auth_data": json_encoder.encode(info["auth_data"]), }, ) return new_version - return self.db.runInteraction( + return self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -387,12 +384,12 @@ def update_e2e_room_keys_version( updatevalues = {} if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) + updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) if version_etag is not None: updatevalues["etag"] = version_etag if updatevalues: - return self.db.simple_update( + return self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -421,19 +418,19 @@ def _delete_e2e_room_keys_version_txn(txn): else: this_version = version - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db.simple_update_one_txn( + return self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py similarity index 89% rename from synapse/storage/data_stores/main/end_to_end_keys.py rename to synapse/storage/databases/main/end_to_end_keys.py index 317c07a8297c..f93e0d320dcf 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,24 +14,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection -from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter class EndToEndKeyWorkerStore(SQLBaseStore): @trace - @defer.inlineCallbacks - def get_e2e_device_keys( + async def get_e2e_device_keys( self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. @@ -51,7 +50,7 @@ def get_e2e_device_keys( if not query_list: return {} - results = yield self.db.runInteraction( + results = await self.db_pool.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -128,7 +127,7 @@ def _get_e2e_device_keys_txn( ) txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) result = {} for row in rows: @@ -146,7 +145,7 @@ def _get_e2e_device_keys_txn( ) txn.execute(signature_sql, signature_query_params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) # add each cross-signing signature to the correct device in the result dict. for row in rows: @@ -174,8 +173,9 @@ def _get_e2e_device_keys_txn( log_kv(result) return result - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + async def get_e2e_one_time_keys( + self, user_id: str, device_id: str, key_ids: List[str] + ) -> Dict[Tuple[str, str], str]: """Retrieve a number of one-time keys for a user Args: @@ -185,11 +185,10 @@ def get_e2e_one_time_keys(self, user_id, device_id, key_ids): retrieve Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key + A map from (algorithm, key_id) to json string for key """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -201,17 +200,21 @@ def get_e2e_one_time_keys(self, user_id, device_id, key_ids): log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + async def add_e2e_one_time_keys( + self, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: """Insert some new one time keys for a device. Errors if any of the keys already exist. Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ def _add_e2e_one_time_keys(txn): @@ -222,7 +225,7 @@ def _add_e2e_one_time_keys(txn): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -241,7 +244,7 @@ def _add_e2e_one_time_keys(txn): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -264,26 +267,27 @@ def _count_e2e_one_time_keys(txn): result[algorithm] = key_count return result - return self.db.runInteraction( + return self.db_pool.runInteraction( "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + async def get_e2e_cross_signing_key( + self, user_id: str, key_type: str, from_user_id: Optional[str] = None + ) -> Optional[dict]: """Returns a user's cross-signing key. Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' + user_id: the user whose key is being requested + key_type: the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on + from_user_id: if specified, signatures made by this user on the self-signing key will be included in the result Returns: dict of the key data or None if not found """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) user_keys = res.get(user_id) if not user_keys: return None @@ -318,7 +322,7 @@ def _get_bare_e2e_cross_signing_keys_bulk( to None. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, @@ -361,7 +365,7 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn( ) txn.execute(sql, params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) for row in rows: user_id = row["user_id"] @@ -420,7 +424,7 @@ def _get_e2e_cross_signing_signatures_txn( query_params.extend(item) txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) # and add the signatures to the appropriate keys for row in rows: @@ -449,28 +453,26 @@ def _get_e2e_cross_signing_signatures_txn( return keys - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: + async def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: Optional[str] = None + ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on + user_ids: the users whose keys are being requested + from_user_id: if specified, signatures made by this user on the self-signing keys will be included in the result Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A map of user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, + or their user ID will map to None. """ - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, result, @@ -531,7 +533,7 @@ def _get_all_user_signature_changes_for_remotes_txn(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_user_signature_changes_for_remotes", _get_all_user_signature_changes_for_remotes_txn, ) @@ -549,7 +551,7 @@ def _set_e2e_device_keys_txn(txn): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self.db.simple_select_one_onecol_txn( + old_key_json = self.db_pool.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -565,7 +567,7 @@ def _set_e2e_device_keys_txn(txn): log_kv({"Message": "Device key already stored."}) return False - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -574,7 +576,9 @@ def _set_e2e_device_keys_txn(txn): log_kv({"message": "Device keys stored."}) return True - return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return self.db_pool.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" @@ -613,7 +617,7 @@ def _claim_e2e_one_time_keys(txn): ) return result - return self.db.runInteraction( + return self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) @@ -626,12 +630,12 @@ def delete_e2e_keys_by_device_txn(txn): "user_id": user_id, } ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -640,7 +644,7 @@ def delete_e2e_keys_by_device_txn(txn): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) @@ -679,7 +683,7 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): # We only need to do this for local users, since remote servers should be # responsible for checking this for their own users. if self.hs.is_mine_id(user_id): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "devices", values={ @@ -692,13 +696,13 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ "user_id": user_id, "keytype": key_type, - "keydata": json.dumps(key), + "keydata": json_encoder.encode(key), "stream_id": stream_id, }, ) @@ -715,7 +719,7 @@ def set_e2e_cross_signing_key(self, user_id, key_type, key): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, @@ -730,7 +734,7 @@ def store_e2e_cross_signing_signatures(self, user_id, signatures): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self.db.simple_insert_many( + return self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/databases/main/event_federation.py similarity index 92% rename from synapse/storage/data_stores/main/event_federation.py rename to synapse/storage/databases/main/event_federation.py index a6bb3221ff21..484875f98992 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -15,16 +15,14 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter @@ -65,7 +63,7 @@ def get_auth_chain_ids( Returns: list of event_ids """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, @@ -114,7 +112,7 @@ def get_auth_chain_difference(self, state_sets: List[Set[str]]): Deferred[Set[str]] """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -260,12 +258,12 @@ def _get_auth_chain_difference_txn( return {eid for eid, n in event_to_missing_sets.items() if n} def get_oldest_events_in_room(self, room_id): - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -286,17 +284,13 @@ def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): return dict(txn) - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): + async def get_max_depth_of(self, event_ids: List[str]) -> int: """Returns the max depth of a set of event IDs Args: - event_ids (list[str]) - - Returns - Deferred[int] + event_ids: The event IDs to calculate the max depth of. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -310,7 +304,7 @@ def get_max_depth_of(self, event_ids): return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db.simple_select_onecol_txn( + return self.db_pool.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -332,7 +326,7 @@ def get_prev_events_for_room(self, room_id: str): """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -387,13 +381,13 @@ def _get_rooms_with_many_extremities_txn(txn): txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -403,12 +397,12 @@ def get_latest_event_ids_in_room(self, room_id): def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self.db.simple_select_one_onecol_txn( + min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -474,7 +468,7 @@ def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -489,7 +483,7 @@ def get_backfill_events(self, room_id, event_list, limit): limit (int) """ return ( - self.db.runInteraction( + self.db_pool.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -520,7 +514,7 @@ def _get_backfill_events(self, txn, room_id, event_list, limit): queue = PriorityQueue() for event_id in event_list: - depth = self.db.simple_select_one_onecol_txn( + depth = self.db_pool.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -550,9 +544,8 @@ def _get_backfill_events(self, txn, room_id, event_list, limit): return event_results - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db.runInteraction( + async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -560,7 +553,7 @@ def get_missing_events(self, room_id, earliest_events, latest_events, limit): latest_events, limit, ) - events = yield self.get_events_as_list(ids) + events = await self.get_events_as_list(ids) return events def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -595,17 +588,13 @@ def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limi event_results.reverse() return event_results - @defer.inlineCallbacks - def get_successor_events(self, event_ids): + async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: """Fetch all events that have the given events as a prev event Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] + event_ids: The events to use as the previous events. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -628,10 +617,10 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventFederationStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -658,13 +647,13 @@ def _delete_old_forward_extrem_cache_txn(txn): return run_as_background_process( "delete_old_forward_extrem_cache", - self.db.runInteraction, + self.db_pool.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.db.runInteraction( + return self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -674,8 +663,7 @@ def _clean_room_for_join_txn(self, txn, room_id): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): + async def _background_delete_non_state_event_auth(self, progress, batch_size): def delete_event_auth(txn): target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") @@ -708,17 +696,19 @@ def delete_event_auth(txn): "max_stream_id_exclusive": min_stream_id, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) return min_stream_id >= target_min_stream_id - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) + await self.db_pool.updates._end_background_update( + self.EVENT_AUTH_STATE_ONLY + ) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py similarity index 95% rename from synapse/storage/data_stores/main/event_push_actions.py rename to synapse/storage/databases/main/event_push_actions.py index ad828389017b..7c246d3e4c48 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -17,11 +17,10 @@ import logging from typing import List -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight): else: if actions == DEFAULT_NOTIF_ACTION: return "" - return json.dumps(actions) + return json_encoder.encode(actions) def _deserialize_action(actions, is_highlight): @@ -66,7 +65,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -91,7 +90,7 @@ def __init__(self, database: Database, db_conn, hs): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.db.runInteraction( + ret = yield self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -176,7 +175,7 @@ def f(txn): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = await self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) return ret async def get_unread_push_actions_for_user_in_range_for_http( @@ -230,7 +229,7 @@ def get_after_receipt(txn): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = await self.db.runInteraction( + after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -258,7 +257,7 @@ def get_no_receipt(txn): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = await self.db.runInteraction( + no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -332,7 +331,7 @@ def get_after_receipt(txn): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = await self.db.runInteraction( + after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -360,7 +359,7 @@ def get_no_receipt(txn): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = await self.db.runInteraction( + no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -410,7 +409,7 @@ def _get_if_maybe_push_in_range_for_user_txn(txn): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -461,7 +460,7 @@ def _add_push_actions_to_staging_txn(txn): ), ) - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -471,7 +470,7 @@ async def remove_push_actions_from_staging(self, event_id: str) -> None: """ try: - res = await self.db.simple_delete( + res = await self.db_pool.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -488,7 +487,7 @@ async def remove_push_actions_from_staging(self, event_id: str) -> None: def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.db.runInteraction, + self.db_pool.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -524,7 +523,7 @@ def find_first_stream_ordering_after_ts(self, ts): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -619,24 +618,26 @@ def f(txn): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = await self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db_pool.runInteraction( + "get_time_of_last_push_action_before", f + ) return result[0] if result else None class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventPushActionsStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", @@ -678,9 +679,9 @@ def f(txn): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - push_actions = await self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @@ -690,7 +691,7 @@ def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = await self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -753,7 +754,7 @@ async def _rotate_notifs(self): while True: logger.info("Rotating notifications") - caught_up = await self.db.runInteraction( + caught_up = await self.db_pool.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -767,7 +768,7 @@ def _rotate_notifs_txn(self, txn): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -803,7 +804,7 @@ def _rotate_notifs_txn(self, txn): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -835,7 +836,7 @@ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/databases/main/events.py similarity index 94% rename from synapse/storage/data_stores/main/events.py rename to synapse/storage/databases/main/events.py index 0c9c02afa181..1a68bf32cb3c 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -32,8 +32,8 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause -from synapse.storage.data_stores.main.search import SearchEntry -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder @@ -41,7 +41,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.data_stores.main import DataStore + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -53,47 +53,6 @@ ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -132,9 +91,11 @@ class PersistEventsStore: Note: This is not part of the `DataStore` mixin. """ - def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): + def __init__( + self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" + ): self.hs = hs - self.db = db + self.db_pool = db self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() @@ -207,7 +168,7 @@ def _persist_events_and_state_updates( for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -237,10 +198,6 @@ def _persist_events_and_state_updates( event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -283,7 +240,7 @@ def _get_events_which_are_prevs_txn(txn, batch): results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) @@ -347,7 +304,7 @@ def _get_prevs_before_rejected_txn(txn, batch): existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -421,7 +378,7 @@ def _persist_events_txn( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -484,7 +441,7 @@ def _update_current_state_txn( """ txn.execute(sql, (stream_id, room_id)) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, ) else: @@ -632,7 +589,7 @@ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): creator = content.get("creator") room_version_id = content.get("room_version", RoomVersions.V1.identifier) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="rooms", keyvalues={"room_id": room_id}, @@ -644,14 +601,14 @@ def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in new_forward_extremities.items(): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after( self.store.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -664,7 +621,7 @@ def _update_forward_extremities_txn( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -788,7 +745,7 @@ def _update_outliers_txn(self, txn, events_and_contexts): # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -826,7 +783,7 @@ def event_dict(event): d.pop("redacted_because", None) return d - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_json", values=[ @@ -843,7 +800,7 @@ def event_dict(event): ], ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="events", values=[ @@ -862,9 +819,8 @@ def event_dict(event): "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) @@ -873,7 +829,7 @@ def event_dict(event): # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -1015,7 +971,9 @@ def _update_metadata_tables_txn( state_values.append(vals) - self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db_pool.simple_insert_many_txn( + txn, table="state_events", values=state_values + ) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1046,7 +1004,7 @@ def _add_to_cache(self, txn, events_and_contexts): ) txn.execute(sql + clause, args) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1066,7 +1024,7 @@ def _store_redaction(self, txn, event): # invalidate the cache for the redacted event txn.call_after(self.store._invalidate_get_event_cache, event.redacts) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="redactions", values={ @@ -1089,7 +1047,7 @@ def insert_labels_for_event_txn( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self.db.simple_insert_many_txn( + return self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1111,7 +1069,7 @@ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self.db.simple_insert_txn( + return self.db_pool.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -1135,12 +1093,14 @@ def _store_event_reference_hashes_txn(self, txn, events): } ) - self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db_pool.simple_insert_many_txn( + txn, table="event_reference_hashes", values=vals + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1180,7 +1140,7 @@ def _store_room_members_txn(self, txn, events, backfilled): and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="local_current_membership", keyvalues={"room_id": event.room_id, "user_id": event.state_key}, @@ -1218,7 +1178,7 @@ def _handle_event_relations(self, txn, event): aggregation_key = relation.get("key") - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="event_relations", values={ @@ -1246,7 +1206,7 @@ def _handle_redaction(self, txn, redacted_event_id): redacted_event_id (str): The event that was redacted. """ - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) @@ -1282,7 +1242,7 @@ def _store_retention_policy_for_room_txn(self, txn, event): # Ignore the event if one of the value isn't an integer. return - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -1363,7 +1323,7 @@ def _set_push_actions_for_event_and_users_txn( ) for event, _ in events_and_contexts: - user_ids = self.db.simple_select_onecol_txn( + user_ids = self.db_pool.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -1395,7 +1355,7 @@ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): ) def _store_rejections_txn(self, txn, event_id, reason): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="rejections", values={ @@ -1421,7 +1381,7 @@ def _store_event_state_mappings_txn( state_groups[event.event_id] = context.state_group - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ @@ -1443,7 +1403,7 @@ def _update_min_depth_for_room_txn(self, txn, room_id, depth): if min_depth is not None and depth >= min_depth: return - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -1455,7 +1415,7 @@ def _handle_mult_prev_events(self, txn, events): For the given event, update the event edges table and forward and backward extremities tables. """ - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_edges", values=[ diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py similarity index 90% rename from synapse/storage/data_stores/main/events_bg_updates.py rename to synapse/storage/databases/main/events_bg_updates.py index 663c94b24fc8..35a0e09e3c83 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -19,7 +19,7 @@ from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool logger = logging.getLogger(__name__) @@ -30,18 +30,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -52,7 +52,7 @@ def __init__(self, database: Database, db_conn, hs): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -61,16 +61,16 @@ def __init__(self, database: Database, db_conn, hs): psql_only=True, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -78,15 +78,15 @@ def __init__(self, database: Database, db_conn, hs): where_clause="have_censored", ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "redactions_have_censored_ts_idx", index_name="redactions_have_censored_ts", table="redactions", @@ -149,18 +149,18 @@ def reindex_txn(txn): "rows_inserted": rows_inserted + len(rows), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) - result = yield self.db.runInteraction( + result = yield self.db_pool.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self.db.updates._end_background_update( + yield self.db_pool.updates._end_background_update( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) @@ -195,7 +195,7 @@ def reindex_search_txn(txn): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self.db.simple_select_many_txn( + ev_rows = self.db_pool.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -228,18 +228,18 @@ def reindex_search_txn(txn): "rows_inserted": rows_inserted + len(rows_to_update), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) - result = yield self.db.runInteraction( + result = yield self.db_pool.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self.db.updates._end_background_update( + yield self.db_pool.updates._end_background_update( self.EVENT_ORIGIN_SERVER_TS_NAME ) @@ -374,7 +374,7 @@ def _cleanup_extremities_bg_update_txn(txn): to_delete.intersection_update(original_set) - deleted = self.db.simple_delete_many_txn( + deleted = self.db_pool.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -390,7 +390,7 @@ def _cleanup_extremities_bg_update_txn(txn): if deleted: # We now need to invalidate the caches of these rooms - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="events", column="event_id", @@ -404,7 +404,7 @@ def _cleanup_extremities_bg_update_txn(txn): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -414,19 +414,19 @@ def _cleanup_extremities_bg_update_txn(txn): return len(original_set) - num_handled = yield self.db.runInteraction( + num_handled = yield self.db_pool.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self.db.updates._end_background_update( + yield self.db_pool.updates._end_background_update( self.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) @@ -474,18 +474,18 @@ def _redactions_received_ts_txn(txn): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) return len(rows) - count = yield self.db.runInteraction( + count = yield self.db_pool.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self.db.updates._end_background_update("redactions_received_ts") + yield self.db_pool.updates._end_background_update("redactions_received_ts") return count @@ -511,11 +511,11 @@ def _event_fix_redactions_bytes_txn(txn): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self.db.updates._end_background_update("event_fix_redactions_bytes") + yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -543,7 +543,7 @@ def _event_store_labels_txn(txn): try: event_json = db_to_json(event_json_raw) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -569,17 +569,17 @@ def _event_store_labels_txn(txn): nbrows += 1 last_row_event_id = event_id - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) return nbrows - num_rows = yield self.db.runInteraction( + num_rows = yield self.db_pool.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self.db.updates._end_background_update("event_store_labels") + yield self.db_pool.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/databases/main/events_worker.py similarity index 93% rename from synapse/storage/data_stores/main/events_worker.py rename to synapse/storage/databases/main/events_worker.py index b03b25963691..755b7a2a85d4 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -40,16 +40,10 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database -from synapse.storage.types import Cursor +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -80,7 +74,7 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.writers.events == hs.get_instance_name(): @@ -136,7 +130,7 @@ def get_received_ts(self, event_id): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -175,7 +169,7 @@ def _get_approximate_received_ts_txn(txn): return ts - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @@ -543,7 +537,7 @@ def _fetch_event_list(self, conn, event_list): event_id for events, _ in event_list for event_id in events } - row_dict = self.db.new_transaction( + row_dict = self.db_pool.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -720,7 +714,7 @@ def _enqueue_events(self, events): if should_start: run_as_background_process( - "fetch_events", self.db.runWithConnection, self._do_fetch + "fetch_events", self.db_pool.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -889,7 +883,7 @@ def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -924,7 +918,7 @@ def have_seen_events_txn(txn, chunk): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -953,7 +947,7 @@ def get_total_state_event_counts(self, room_id): Returns: Deferred[int] """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -978,7 +972,7 @@ def get_current_state_event_counts(self, room_id): Returns: Deferred[int] """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -1043,7 +1037,7 @@ def get_all_new_forward_event_rows(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1077,7 +1071,7 @@ def get_ex_outlier_stream_rows_txn(txn): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) @@ -1151,7 +1145,7 @@ def get_all_new_backfill_event_rows(txn): return new_event_updates, upper_bound, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1199,7 +1193,7 @@ def get_deltas_for_stream_id_txn(txn, stream_id): # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows = await self.db.runInteraction( + rows = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) # type: List[Tuple] @@ -1222,7 +1216,7 @@ def get_deltas_for_stream_id_txn(txn, stream_id): # stream id. let's run the query again, without a row limit, but for # just one stream id. to_token += 1 - rows = await self.db.runInteraction( + rows = await self.db_pool.runInteraction( "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) @@ -1317,7 +1311,7 @@ def get_all_new_events_txn(txn): backward_ex_outliers, ) - return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) + return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream @@ -1328,7 +1322,7 @@ async def is_event_after(self, event_id1, event_id2): @cachedInlineCallbacks(max_entries=5000) def get_event_ordering(self, event_id): - res = yield self.db.simple_select_one( + res = yield self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1360,88 +1354,10 @@ def get_next_event_to_expire_txn(txn): return txn.fetchone() - return self.db.runInteraction( + return self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/databases/main/filtering.py similarity index 89% rename from synapse/storage/data_stores/main/filtering.py rename to synapse/storage/databases/main/filtering.py index 342d6622a458..45a1760170bc 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,12 @@ from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +30,7 @@ def get_user_filter(self, user_localpart, filter_id): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -71,4 +71,4 @@ def _do_txn(txn): return filter_id - return self.db.runInteraction("add_user_filter", _do_txn) + return self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/databases/main/group_server.py similarity index 87% rename from synapse/storage/data_stores/main/group_server.py rename to synapse/storage/databases/main/group_server.py index 01ff561e1a61..380db3a3f34e 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,14 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - -from canonicaljson import json - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict +from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -31,7 +29,7 @@ class GroupServerWorkerStore(SQLBaseStore): def get_group(self, group_id): - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -53,7 +51,7 @@ def get_users_in_group(self, group_id, include_private=False): if not include_private: keyvalues["is_public"] = True - return self.db.simple_select_list( + return self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -63,7 +61,7 @@ def get_users_in_group(self, group_id, include_private=False): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -117,7 +115,9 @@ def _get_rooms_in_group_txn(txn): for room_id, is_public in txn ] - return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + return self.db_pool.runInteraction( + "get_rooms_in_group", _get_rooms_in_group_txn + ) def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, @@ -205,13 +205,12 @@ def _get_rooms_for_summary_txn(txn): return rooms, categories - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -226,9 +225,8 @@ def get_group_categories(self, group_id): for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -239,9 +237,8 @@ def get_group_category(self, group_id, category_id): return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -256,9 +253,8 @@ def get_group_roles(self, group_id): for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -277,7 +273,7 @@ def get_local_groups_for_room(self, room_id): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -341,12 +337,12 @@ def _get_users_for_summary_txn(txn): return users, roles - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) def is_user_in_group(self, user_id, group_id): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -355,7 +351,7 @@ def is_user_in_group(self, user_id, group_id): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -366,7 +362,7 @@ def is_user_admin_in_group(self, group_id, user_id): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -389,7 +385,7 @@ def get_users_membership_info_in_group(self, group_id, user_id): """ def _get_users_membership_in_group_txn(txn): - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -404,7 +400,7 @@ def _get_users_membership_in_group_txn(txn): "is_privileged": row["is_admin"], } - row = self.db.simple_select_one_onecol_txn( + row = self.db_pool.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -417,14 +413,14 @@ def _get_users_membership_in_group_txn(txn): return {} - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -441,18 +437,17 @@ def _get_attestations_need_renewals_txn(txn): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.db.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -467,7 +462,7 @@ def get_remote_attestation(self, group_id, user_id): return None def get_joined_groups(self, user_id): - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -494,17 +489,17 @@ def _get_all_groups_for_user_txn(txn): for row in txn ] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -524,7 +519,7 @@ def _get_groups_changes_for_user_txn(txn): for group_id, membership, gtype, content_json in txn ] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -579,7 +574,7 @@ def _get_all_groups_changes_txn(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) @@ -592,7 +587,7 @@ def set_group_join_policy(self, group_id, join_policy): * "invite" * "open" """ - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -600,7 +595,7 @@ def set_group_join_policy(self, group_id, join_policy): ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db.runInteraction( + return self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -624,7 +619,7 @@ def _add_room_to_summary_txn( an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self.db.simple_select_one_onecol_txn( + room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -637,7 +632,7 @@ def _add_room_to_summary_txn( if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self.db.simple_select_one_onecol_txn( + cat_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -648,7 +643,7 @@ def _add_room_to_summary_txn( raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self.db.simple_select_one_onecol_txn( + cat_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -668,7 +663,7 @@ def _add_room_to_summary_txn( (group_id, category_id, group_id, category_id), ) - existing = self.db.simple_select_one_txn( + existing = self.db_pool.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -701,7 +696,7 @@ def _add_room_to_summary_txn( to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -715,7 +710,7 @@ def _add_room_to_summary_txn( if is_public is None: is_public = True - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -731,7 +726,7 @@ def remove_room_from_summary(self, group_id, room_id, category_id): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.db.simple_delete( + return self.db_pool.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -750,14 +745,14 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True else: update_values["is_public"] = is_public - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -766,7 +761,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public): ) def remove_group_category(self, group_id, category_id): - return self.db.simple_delete( + return self.db_pool.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -781,14 +776,14 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True else: update_values["is_public"] = is_public - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -797,14 +792,14 @@ def upsert_group_role(self, group_id, role_id, profile, is_public): ) def remove_group_role(self, group_id, role_id): - return self.db.simple_delete( + return self.db_pool.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db.runInteraction( + return self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -828,7 +823,7 @@ def _add_user_to_summary_txn( an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self.db.simple_select_one_onecol_txn( + user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -841,7 +836,7 @@ def _add_user_to_summary_txn( if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self.db.simple_select_one_onecol_txn( + role_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -852,7 +847,7 @@ def _add_user_to_summary_txn( raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self.db.simple_select_one_onecol_txn( + role_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -872,7 +867,7 @@ def _add_user_to_summary_txn( (group_id, role_id, group_id, role_id), ) - existing = self.db.simple_select_one_txn( + existing = self.db_pool.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -901,7 +896,7 @@ def _add_user_to_summary_txn( to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -915,7 +910,7 @@ def _add_user_to_summary_txn( if is_public is None: is_public = True - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_summary_users", values={ @@ -931,7 +926,7 @@ def remove_user_from_summary(self, group_id, user_id, role_id): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.db.simple_delete( + return self.db_pool.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -940,7 +935,7 @@ def remove_user_from_summary(self, group_id, user_id, role_id): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self.db.simple_insert( + return self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -970,7 +965,7 @@ def add_user_to_group( """ def _add_user_to_group_txn(txn): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_users", values={ @@ -981,14 +976,14 @@ def _add_user_to_group_txn(txn): }, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -998,60 +993,60 @@ def _add_user_to_group_txn(txn): }, ) if remote_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_remote", values={ "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) - return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self.db.simple_insert( + return self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db.simple_update( + return self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -1060,67 +1055,67 @@ def update_room_in_group_visibility(self, group_id, room_id, is_public): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. """ def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="local_group_membership", values={ @@ -1129,11 +1124,11 @@ def _register_user_group_membership_txn(txn, next_id): "is_admin": is_admin, "membership": membership, "is_publicised": is_publicised, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="local_group_updates", values={ @@ -1141,7 +1136,7 @@ def _register_user_group_membership_txn(txn, next_id): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps( + "content": json_encoder.encode( {"membership": membership, "content": content} ), }, @@ -1152,7 +1147,7 @@ def _register_user_group_membership_txn(txn, next_id): if membership == "join": if local_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -1162,23 +1157,23 @@ def _register_user_group_membership_txn(txn, next_id): }, ) if remote_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_remote", values={ "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) else: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1187,18 +1182,17 @@ def _register_user_group_membership_txn(txn, next_id): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db.runInteraction( + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1211,9 +1205,8 @@ def create_group( desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1223,7 +1216,7 @@ def update_group_profile(self, group_id, profile): def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1233,12 +1226,12 @@ def update_attestation_renewal(self, group_id, user_id, attestation): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), + "attestation_json": json_encoder.encode(attestation), }, desc="update_remote_attestion", ) @@ -1252,7 +1245,7 @@ def remove_attestation_renewal(self, group_id, user_id): group_id (str) user_id (str) """ - return self.db.simple_delete( + return self.db_pool.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1288,8 +1281,8 @@ def _delete_group_txn(txn): ] for table in tables: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.db.runInteraction("delete_group", _delete_group_txn) + return self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/databases/main/keys.py similarity index 95% rename from synapse/storage/data_stores/main/keys.py rename to synapse/storage/databases/main/keys.py index 4e1642a27a59..384e9c5eb0f1 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -86,7 +86,7 @@ def _txn(txn): _get_keys(txn, batch) return keys - return self.db.runInteraction("get_server_verify_keys", _txn) + return self.db_pool.runInteraction("get_server_verify_keys", _txn) def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): """Stores NACL verification keys for remote servers. @@ -121,9 +121,9 @@ def _invalidate(res): f((i,)) return res - return self.db.runInteraction( + return self.db_pool.runInteraction( "store_server_verify_keys", - self.db.simple_upsert_many_txn, + self.db_pool.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -151,7 +151,7 @@ def store_server_keys_json( ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -190,7 +190,7 @@ def _get_server_keys_json_txn(txn): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -205,4 +205,6 @@ def _get_server_keys_json_txn(txn): results[(server_name, key_id, from_server)] = rows return results - return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return self.db_pool.runInteraction( + "get_server_keys_json", _get_server_keys_json_txn + ) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/databases/main/media_repository.py similarity index 89% rename from synapse/storage/data_stores/main/media_repository.py rename to synapse/storage/databases/main/media_repository.py index 15bc13cbd0ee..80fc1cd0092a 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryBackgroundUpdateStore, self).__init__( database, db_conn, hs ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", @@ -34,7 +34,7 @@ def __init__(self, database: Database, db_conn, hs): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): @@ -42,7 +42,7 @@ def get_local_media(self, media_id): Returns: None if the media_id doesn't exist. """ - return self.db.simple_select_one( + return self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -67,7 +67,7 @@ def store_local_media( user_id, url_cache=None, ): - return self.db.simple_insert( + return self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -83,7 +83,7 @@ def store_local_media( def mark_local_media_as_safe(self, media_id: str): """Mark a local media as safe from quarantining.""" - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, updatevalues={"safe_from_quarantine": True}, @@ -136,12 +136,12 @@ def get_url_cache_txn(txn): ) ) - return self.db.runInteraction("get_url_cache", get_url_cache_txn) + return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db.simple_insert( + return self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -156,7 +156,7 @@ def store_url_cache( ) def get_local_media_thumbnails(self, media_id): - return self.db.simple_select_list( + return self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -178,7 +178,7 @@ def store_local_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db.simple_insert( + return self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -192,7 +192,7 @@ def store_local_thumbnail( ) def get_cached_remote_media(self, origin, media_id): - return self.db.simple_select_one( + return self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -217,7 +217,7 @@ def store_cached_remote_media( upload_name, filesystem_id, ): - return self.db.simple_insert( + return self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -262,12 +262,12 @@ def update_cache_txn(txn): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) def get_remote_media_thumbnails(self, origin, media_id): - return self.db.simple_select_list( + return self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -292,7 +292,7 @@ def store_remote_media_thumbnail( thumbnail_method, thumbnail_length, ): - return self.db.simple_insert( + return self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -314,24 +314,26 @@ def get_remote_media_before(self, before_ts): " WHERE last_access_ts < ?" ) - return self.db.execute( - "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts + return self.db_pool.execute( + "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db_pool.runInteraction( + "delete_remote_media", delete_remote_media_txn + ) def get_expired_url_cache(self, now_ts): sql = ( @@ -345,7 +347,7 @@ def _get_expired_url_cache_txn(txn): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_expired_url_cache", _get_expired_url_cache_txn ) @@ -358,7 +360,9 @@ async def delete_url_cache(self, media_ids): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db_pool.runInteraction( + "delete_url_cache", _delete_url_cache_txn + ) def get_url_cache_media_before(self, before_ts): sql = ( @@ -372,7 +376,7 @@ def _get_url_cache_media_before_txn(txn): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -389,6 +393,6 @@ def _delete_url_cache_media_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/databases/main/metrics.py similarity index 83% rename from synapse/storage/data_stores/main/metrics.py rename to synapse/storage/databases/main/metrics.py index dad5bbc60261..686052bd83c0 100644 --- a/synapse/storage/data_stores/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -15,15 +15,13 @@ import typing from collections import Counter -from twisted.internet import defer - from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.database import Database class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): @@ -31,7 +29,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. @@ -66,11 +64,10 @@ def fetch(txn): ) return txn.fetchall() - res = await self.db.runInteraction("read_forward_extremities", fetch) + res = await self.db_pool.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = Counter([x[0] for x in res]) - @defer.inlineCallbacks - def count_daily_messages(self): + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -88,11 +85,9 @@ def _count_messages(txn): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_messages", _count_messages) - return ret + return await self.db_pool.runInteraction("count_messages", _count_messages) - @defer.inlineCallbacks - def count_daily_sent_messages(self): + async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then thats your own fault. @@ -109,11 +104,11 @@ def _count_messages(txn): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) - return ret + return await self.db_pool.runInteraction( + "count_daily_sent_messages", _count_messages + ) - @defer.inlineCallbacks - def count_daily_active_rooms(self): + async def count_daily_active_rooms(self): def _count(txn): sql = """ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events @@ -124,5 +119,4 @@ def _count(txn): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_daily_active_rooms", _count) - return ret + return await self.db_pool.runInteraction("count_daily_active_rooms", _count) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py similarity index 90% rename from synapse/storage/data_stores/main/monthly_active_users.py rename to synapse/storage/databases/main/monthly_active_users.py index e459cf49a0b1..e71cdd2cb4e2 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,10 +15,8 @@ import logging from typing import List -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_in_list_sql_clause +from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -29,7 +27,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -48,7 +46,7 @@ def _count_users(txn): (count,) = txn.fetchone() return count - return self.db.runInteraction("count_users", _count_users) + return self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) def get_monthly_active_count_by_service(self): @@ -76,7 +74,9 @@ def _count_users_by_service(txn): result = txn.fetchall() return dict(result) - return self.db.runInteraction("count_users_by_service", _count_users_by_service) + return self.db_pool.runInteraction( + "count_users_by_service", _count_users_by_service + ) async def get_registered_reserved_users(self) -> List[str]: """Of the reserved threepids defined in config, retrieve those that are associated @@ -109,7 +109,7 @@ def user_last_seen_monthly_active(self, user_id): """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", @@ -119,7 +119,7 @@ def user_last_seen_monthly_active(self, user_id): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._limit_usage_by_mau = hs.config.limit_usage_by_mau @@ -128,7 +128,7 @@ def __init__(self, database: Database, db_conn, hs): # Do not add more reserved users than the total allowable number # cur = LoggingTransaction( - self.db.new_transaction( + self.db_pool.new_transaction( db_conn, "initialise_mau_threepids", [], @@ -162,7 +162,7 @@ def _initialise_reserved_users(self, txn, threepids): is_support = self.is_support_user_txn(txn, user_id) if not is_support: # We do this manually here to avoid hitting #6791 - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -246,20 +246,16 @@ def _reap_users(txn, reserved_users): self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) reserved_users = await self.get_registered_reserved_users() - await self.db.runInteraction( + await self.db_pool.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server Args: - user_id (str): user to add/update - - Returns: - Deferred + user_id: user to add/update """ # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -269,11 +265,11 @@ def upsert_monthly_active_user(self, user_id): # _initialise_reserved_users reasoning that it would be very strange to # include a support user in this context. - is_support = yield self.is_support_user(user_id) + is_support = await self.is_support_user(user_id) if is_support: return - yield self.db.runInteraction( + await self.db_pool.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -303,7 +299,7 @@ def upsert_monthly_active_user_txn(self, txn, user_id): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.db.simple_upsert_txn( + is_insert = self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -320,8 +316,7 @@ def upsert_monthly_active_user_txn(self, txn, user_id): return is_insert - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -330,14 +325,14 @@ def populate_monthly_active_users(self, user_id): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) + is_guest = await self.is_guest(user_id) if is_guest: return - is_trial = yield self.is_trial_user(user_id) + is_trial = await self.is_trial_user(user_id) if is_trial: return - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + last_seen_timestamp = await self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy @@ -350,10 +345,10 @@ def populate_monthly_active_users(self, user_id): # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) else: - count = yield self.get_monthly_active_count() + count = await self.get_monthly_active_count() if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/databases/main/openid.py similarity index 91% rename from synapse/storage/data_stores/main/openid.py rename to synapse/storage/databases/main/openid.py index cc21437e920e..dcd1ff911a20 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -3,7 +3,7 @@ class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db.simple_insert( + return self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,6 +28,6 @@ def get_user_id_for_token_txn(txn): else: return rows[0][0] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/databases/main/presence.py similarity index 94% rename from synapse/storage/data_stores/main/presence.py rename to synapse/storage/databases/main/presence.py index 757461261936..59ba12820ac3 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,6 @@ from typing import List, Tuple -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList @@ -24,14 +22,13 @@ class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): + async def update_presence(self, presence_states): stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -48,7 +45,7 @@ def _update_presence_txn(self, txn, stream_orderings, presence_states): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -124,7 +121,7 @@ def get_all_presence_updates_txn(txn): return updates, upper_bound, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -139,7 +136,7 @@ def _get_presence_for_user(self, user_id): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -165,7 +162,7 @@ def get_current_presence_token(self): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_insert( + return self.db_pool.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -176,7 +173,7 @@ def allow_presence_visible(self, observed_localpart, observer_userid): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_delete_one( + return self.db_pool.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/databases/main/profile.py similarity index 83% rename from synapse/storage/data_stores/main/profile.py rename to synapse/storage/databases/main/profile.py index bfc9369f0b58..b8261357d489 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.roommember import ProfileInfo +from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart): try: - profile = yield self.db.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +39,7 @@ def get_profileinfo(self, user_localpart): ) def get_profile_displayname(self, user_localpart): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +47,7 @@ def get_profile_displayname(self, user_localpart): ) def get_profile_avatar_url(self, user_localpart): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +55,7 @@ def get_profile_avatar_url(self, user_localpart): ) def get_from_remote_profile_cache(self, user_id): - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +64,12 @@ def get_from_remote_profile_cache(self, user_id): ) def create_profile(self, user_localpart): - return self.db.simple_insert( + return self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +77,7 @@ def set_profile_displayname(self, user_localpart, new_displayname): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +92,7 @@ def add_remote_profile_cache(self, user_id, displayname, avatar_url): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +104,7 @@ def add_remote_profile_cache(self, user_id, displayname, avatar_url): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db.simple_update( + return self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ @@ -118,14 +115,13 @@ def update_remote_profile_cache(self, user_id, displayname, avatar_url): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -144,18 +140,17 @@ def _get_remote_profile_cache_entries_that_expire_txn(txn): txn.execute(sql, (last_checked,)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +161,7 @@ def is_subscribed_remote_profile_for_user(self, user_id): if res: return True - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/databases/main/purge_events.py similarity index 98% rename from synapse/storage/data_stores/main/purge_events.py rename to synapse/storage/databases/main/purge_events.py index b53fe35c338b..3526b6fd6696 100644 --- a/synapse/storage/data_stores/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -18,7 +18,7 @@ from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def purge_history(self, room_id, token, delete_local_events): deleted events. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -293,7 +293,7 @@ def purge_room(self, room_id): Deferred[List[int]]: The list of state groups to delete. """ - return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/databases/main/push_rule.py similarity index 81% rename from synapse/storage/data_stores/main/push_rule.py rename to synapse/storage/databases/main/push_rule.py index c22924810141..6562db5c2bde 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,28 +18,27 @@ import logging from typing import List, Tuple, Union -from canonicaljson import json - from twisted.internet import defer from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -def _load_rules(rawrules, enabled_map): +def _load_rules(rawrules, enabled_map, use_new_defaults=False): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -49,7 +48,7 @@ def _load_rules(rawrules, enabled_map): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) + rules = list(list_with_base_rules(ruleslist, use_new_defaults)) for i, rule in enumerate(rules): rule_id = rule["rule_id"] @@ -79,7 +78,7 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: @@ -91,7 +90,7 @@ def __init__(self, database: Database, db_conn, hs): db_conn, "push_rules_stream", "stream_id" ) - push_rules_prefill, push_rules_id = self.db.get_cache_dict( + push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -105,6 +104,8 @@ def __init__(self, database: Database, db_conn, hs): prefilled_cache=push_rules_prefill, ) + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. @@ -116,7 +117,7 @@ def get_max_push_rules_stream_id(self): @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self.db.simple_select_list( + rows = yield self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -134,13 +135,15 @@ def get_push_rules_for_user(self, user_id): enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - rules = _load_rules(rows, enabled_map) + use_new_defaults = user_id in self._users_new_default_push_rules + + rules = _load_rules(rows, enabled_map, use_new_defaults) return rules @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self.db.simple_select_list( + results = yield self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -162,7 +165,7 @@ def have_push_rules_changed_txn(txn): (count,) = txn.fetchone() return bool(count) - return self.db.runInteraction( + return self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @@ -178,7 +181,7 @@ def bulk_get_push_rules(self, user_ids): results = {user_id: [] for user_id in user_ids} - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -194,7 +197,11 @@ def bulk_get_push_rules(self, user_ids): enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + use_new_defaults = user_id in self._users_new_default_push_rules + + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, + ) return results @@ -249,81 +256,6 @@ def copy_push_rules_from_room_to_room_for_user( ): yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = { - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - } - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", @@ -336,7 +268,7 @@ def bulk_get_push_rules_enabled(self, user_ids): results = {user_id: {} for user_id in user_ids} - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -394,7 +326,7 @@ def get_all_push_rule_updates_txn(txn): return updates, upper_bound, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) @@ -411,12 +343,12 @@ def add_push_rule( before=None, after=None, ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) + conditions_json = json_encoder.encode(conditions) + actions_json = json_encoder.encode(actions) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -430,7 +362,7 @@ def add_push_rule( after, ) else: - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -461,7 +393,7 @@ def _add_push_rule_relative_txn( relative_to_rule = before or after - res = self.db.simple_select_one_txn( + res = self.db_pool.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -584,7 +516,7 @@ def _upsert_push_rule_txn( # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="push_rules", values={ @@ -627,7 +559,7 @@ def delete_push_rule(self, user_id, rule_id): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -637,7 +569,7 @@ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, @@ -648,7 +580,7 @@ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): def set_push_rule_enabled(self, user_id, rule_id, enabled): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -662,7 +594,7 @@ def _set_push_rule_enabled_txn( self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -681,7 +613,7 @@ def _set_push_rule_enabled_txn( @defer.inlineCallbacks def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: @@ -702,7 +634,7 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): update_stream=False, ) else: - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -721,7 +653,7 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -741,7 +673,7 @@ def _insert_push_rules_update_txn( if data is not None: values.update(data) - self.db.simple_insert_txn(txn, "push_rules_stream", values=values) + self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/databases/main/pusher.py similarity index 92% rename from synapse/storage/data_stores/main/pusher.py rename to synapse/storage/databases/main/pusher.py index e18f1ca87c86..b5200fbe79cc 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -50,7 +50,7 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self.db.simple_select_one_onecol( + ret = yield self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -63,7 +63,7 @@ def get_pushers_by_user_id(self, user_id): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self.db.simple_select_list( + ret = yield self.db_pool.simple_select_list( "pushers", keyvalues, [ @@ -91,11 +91,11 @@ def get_pushers_by(self, keyvalues): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.db.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers) return rows async def get_all_updated_pushers_rows( @@ -160,7 +160,7 @@ def get_all_updated_pushers_rows_txn(txn): return updates, upper_bound, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -176,7 +176,7 @@ def get_if_user_has_pusher(self, user_id): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -193,7 +193,7 @@ def get_if_users_have_pushers(self, user_ids): def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self.db.simple_update_one( + yield self.db_pool.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -216,7 +216,7 @@ def update_pusher_last_stream_ordering_and_success( Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self.db.simple_update( + updated = yield self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -230,7 +230,7 @@ def update_pusher_last_stream_ordering_and_success( @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db.simple_update( + yield self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -239,7 +239,7 @@ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self.db.simple_select_list( + res = yield self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -259,7 +259,7 @@ def get_throttle_params_by_room(self, pusher_id): def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.db.simple_upsert( + yield self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, @@ -291,7 +291,7 @@ def add_pusher( with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.db.simple_upsert( + yield self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -316,7 +316,7 @@ def add_pusher( if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.db.runInteraction( + yield self.db_pool.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -330,7 +330,7 @@ def delete_pusher_txn(txn, stream_id): txn, self.get_if_user_has_pusher, (user_id,) ) - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -339,7 +339,7 @@ def delete_pusher_txn(txn, stream_id): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -351,4 +351,6 @@ def delete_pusher_txn(txn, stream_id): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + yield self.db_pool.runInteraction( + "delete_pusher", delete_pusher_txn, stream_id + ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/databases/main/receipts.py similarity index 93% rename from synapse/storage/data_stores/main/receipts.py rename to synapse/storage/databases/main/receipts.py index 1d723f2d347e..1920a8a152af 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -18,13 +18,12 @@ import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -41,7 +40,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( @@ -64,7 +63,7 @@ def get_users_with_read_receipts_in_room(self, room_id): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self.db.simple_select_list( + return self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -73,7 +72,7 @@ def get_receipts_for_room(self, room_id, receipt_type): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -87,7 +86,7 @@ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.db.simple_select_list( + rows = yield self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -111,7 +110,9 @@ def f(txn): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db_pool.runInteraction( + "get_receipts_for_user_with_orderings", f + ) return { row[0]: { "event_id": row[1], @@ -190,11 +191,11 @@ def f(txn): txn.execute(sql, (room_id, to_key)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) return rows - rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -240,9 +241,9 @@ def f(txn): txn.execute(sql + clause, [to_key] + list(args)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - txn_results = yield self.db.runInteraction( + txn_results = yield self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) @@ -288,7 +289,7 @@ def _get_users_sent_receipts_between_txn(txn): return [r[0] for r in txn] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) @@ -340,7 +341,7 @@ def get_all_updated_receipts_txn(txn): return updates, upper_bound, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -371,7 +372,7 @@ def _invalidate_get_users_with_receipts_in_room( class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( @@ -393,7 +394,7 @@ def insert_linearized_receipt_txn( otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self.db.simple_select_one_txn( + res = self.db_pool.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -446,7 +447,7 @@ def insert_linearized_receipt_txn( (user_id, room_id, receipt_type), ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ @@ -457,7 +458,7 @@ def insert_linearized_receipt_txn( values={ "stream_id": stream_id, "event_id": event_id, - "data": json.dumps(data), + "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -506,13 +507,13 @@ def graph_to_linear(txn): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.db.runInteraction( + linearized_event_id = yield self.db_pool.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.db.runInteraction( + event_ts = yield self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -541,7 +542,7 @@ def graph_to_linear(txn): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db.runInteraction( + return self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -567,7 +568,7 @@ def insert_graph_receipt_txn( self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -576,14 +577,14 @@ def insert_graph_receipt_txn( "user_id": user_id, }, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="receipts_graph", values={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), + "event_ids": json_encoder.encode(event_ids), + "data": json_encoder.encode(data), }, ) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/databases/main/registration.py similarity index 83% rename from synapse/storage/data_stores/main/registration.py rename to synapse/storage/databases/main/registration.py index 27d2c5028c42..402ae25571d0 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,20 +17,19 @@ import logging import re -from typing import Optional +from typing import Dict, List, Optional -from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 @@ -38,7 +37,7 @@ class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -50,7 +49,7 @@ def __init__(self, database: Database, db_conn, hs): @cached() def get_user_by_id(self, user_id): - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -69,19 +68,15 @@ def get_user_by_id(self, user_id): desc="get_user_by_id", ) - @defer.inlineCallbacks - def is_trial_user(self, user_id): + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config Args: - user_id (str) - - Returns: - Deferred[bool] + user_id: The user to check for trial status. """ - info = yield self.get_user_by_id(user_id) + info = await self.get_user_by_id(user_id) if not info: return False @@ -101,50 +96,51 @@ def get_user_by_access_token(self, token): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): + @cached() + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: """Get the expiration timestamp for the account bearing a given user ID. Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + None, if the account has no expiration timestamp, otherwise int + representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, desc="get_expiration_ts_for_user", ) - return res - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): + async def set_account_validity_for_user( + self, + user_id: str, + expiration_ts: int, + email_sent: bool, + renewal_token: Optional[str] = None, + ) -> None: """Updates the account validity properties of the given account, with the given values. Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds + user_id: ID of the account to update properties for. + expiration_ts: New expiration date, as a timestamp in milliseconds since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity + email_sent: True means a renewal email has been sent for this account + and there's no need to send another one for the current validity period. - renewal_token (str): Renewal token the user can use to extend the validity + renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. """ def set_account_validity_for_user_txn(txn): - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -158,75 +154,69 @@ def set_account_validity_for_user_txn(txn): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): + async def set_renewal_token_for_user( + self, user_id: str, renewal_token: str + ) -> None: """Defines a renewal token for a given user. Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the + user_id: ID of the user to set the renewal token for. + renewal_token: Random unique string that will be used to renew the user's account. Raises: StoreError: The provided token is already set for another user. """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, desc="set_renewal_token_for_user", ) - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): + async def get_user_from_renewal_token(self, renewal_token: str) -> str: """Get a user ID from a renewal token. Args: - renewal_token (str): The renewal token to perform the lookup with. + renewal_token: The renewal token to perform the lookup with. Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. + The ID of the user to which the token belongs. """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", desc="get_user_from_renewal_token", ) - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. Args: - user_id (str): The user ID to lookup a token for. + user_id: The user ID to lookup a token for. Returns: - defer.Deferred[str]: The renewal token associated with this user ID. + The renewal token associated with this user ID. """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", desc="get_renewal_token_for_user", ) - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): + async def get_users_expiring_soon(self) -> List[Dict[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + A list of dictionaries mapping user ID to expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -236,58 +226,54 @@ def select_users_txn(txn, now_ms, renew_at): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - res = yield self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), self.config.account_validity.renew_at, ) - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): + async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: """Sets or unsets the flag that indicates whether a renewal email has been sent to the user (and the user hasn't renewed their account yet). Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent + user_id: ID of the user to set/unset the flag for. + email_sent: Flag which indicates whether a renewal email has been sent to this user. """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, desc="set_renewal_mail_status", ) - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): + async def delete_account_validity_for_user(self, user_id: str) -> None: """Deletes the entry for the given user in the account validity table, removing their expiration date and renewal token. Args: - user_id (str): ID of the user to remove from the account validity table. + user_id: ID of the user to remove from the account validity table. """ - yield self.db.simple_delete_one( + await self.db_pool.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", ) - async def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test + user: user ID of the user to test - Returns (bool): + Returns: true iff the user is a server admin, false otherwise. """ - res = await self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -307,14 +293,14 @@ def set_server_admin(self, user, admin): """ def set_server_admin_txn(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} ) self._invalidate_cache_and_stream( txn, self.get_user_by_id, (user.to_string(),) ) - return self.db.runInteraction("set_server_admin", set_server_admin_txn) + return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( @@ -326,43 +312,42 @@ def _query_for_auth(self, txn, token): ) txn.execute(sql, (token,)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0] return None - @cachedInlineCallbacks() - def is_real_user(self, user_id): + @cached() + async def is_real_user(self, user_id: str) -> bool: """Determines if the user is a real user, ie does not have a 'user_type'. Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user 'user_type' is null or empty string + True if user 'user_type' is null or empty string """ - res = yield self.db.runInteraction( + return await self.db_pool.runInteraction( "is_real_user", self.is_real_user_txn, user_id ) - return res @cached() - def is_support_user(self, user_id): + async def is_support_user(self, user_id: str) -> bool: """Determines if the user is of type UserTypes.SUPPORT Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT + True if user is of type UserTypes.SUPPORT """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) def is_real_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( + res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -372,7 +357,7 @@ def is_real_user_txn(self, txn, user_id): return res is None def is_support_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( + res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -391,7 +376,7 @@ def f(txn): txn.execute(sql, (user_id,)) return dict(txn) - return self.db.runInteraction("get_users_by_id_case_insensitive", f) + return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -405,7 +390,7 @@ async def get_user_by_external_id( Returns: str|None: the mxid of the user, or None if they are not known """ - return await self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -413,19 +398,17 @@ async def get_user_by_external_id( desc="get_user_by_external_id", ) - @defer.inlineCallbacks - def count_all_users(self): + async def count_all_users(self): """Counts all users registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.db.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) def count_daily_user_type(self): """ @@ -456,10 +439,11 @@ def _count_daily_user_type(txn): results[row[0]] = row[1] return results - return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) + return self.db_pool.runInteraction( + "count_daily_user_type", _count_daily_user_type + ) - @defer.inlineCallbacks - def count_nonbridged_users(self): + async def count_nonbridged_users(self): def _count_users(txn): txn.execute( """ @@ -470,29 +454,26 @@ def _count_users(txn): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) - @defer.inlineCallbacks - def count_real_users(self): + async def count_real_users(self): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.db.runInteraction("count_real_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_real_users", _count_users) async def generate_user_id(self) -> str: """Generate a suitable localpart for a guest user Returns: a (hopefully) free localpart """ - next_id = await self.db.runInteraction( + next_id = await self.db_pool.runInteraction( "generate_user_id", self._user_id_seq.get_next_id_txn ) @@ -508,7 +489,7 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s Returns: The user ID or None if no user id/threepid mapping exists """ - user_id = await self.db.runInteraction( + user_id = await self.db_pool.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -524,7 +505,7 @@ def get_user_id_by_threepid_txn(self, txn, medium, address): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self.db.simple_select_one_txn( + ret = self.db_pool.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -535,26 +516,23 @@ def get_user_id_by_threepid_txn(self, txn, medium, address): return ret["user_id"] return None - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.db.simple_upsert( + async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self.db.simple_select_list( + async def user_get_threepids(self, user_id): + return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], "user_get_threepids", ) - return ret def user_delete_threepid(self, user_id, medium, address): - return self.db.simple_delete( + return self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -567,7 +545,7 @@ def user_delete_threepids(self, user_id: str): user_id: The user id to delete all threepids of """ - return self.db.simple_delete( + return self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -589,7 +567,7 @@ def add_user_bound_threepid(self, user_id, medium, address, id_server): """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -615,7 +593,7 @@ def user_get_bound_threepids(self, user_id): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db.simple_select_list( + return self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -636,7 +614,7 @@ def remove_user_bound_threepid(self, user_id, medium, address, id_server): Returns: Deferred """ - return self.db.simple_delete( + return self.db_pool.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -659,25 +637,25 @@ def get_id_servers_user_bound(self, user_id, medium, address): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): + @cached() + async def get_user_deactivated_status(self, user_id: str) -> bool: """Retrieve the value for the `deactivated` property for the provided user. Args: - user_id (str): The ID of the user to retrieve the status for. + user_id: The ID of the user to retrieve the status for. Returns: - defer.Deferred(bool): The requested value. + True if the user was deactivated, false if the user is still active. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -744,13 +722,13 @@ def get_threepid_validation_session_txn(txn): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -764,37 +742,37 @@ def delete_threepid_session(self, session_id): """ def delete_threepid_session_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -804,18 +782,19 @@ def __init__(self, database: Database, db_conn, hs): # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.db.updates.register_noop_background_update("refresh_tokens_device_index") + self.db_pool.updates.register_noop_background_update( + "refresh_tokens_device_index" + ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ @@ -843,7 +822,7 @@ def _background_update_set_deactivated_flag_txn(txn): (last_user, batch_size), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return True, 0 @@ -857,7 +836,7 @@ def _background_update_set_deactivated_flag_txn(txn): logger.info("Marked %d rows as deactivated", rows_processed_nb) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -866,17 +845,18 @@ def _background_update_set_deactivated_flag_txn(txn): else: return False, len(rows) - end, nb_processed = yield self.db.runInteraction( + end, nb_processed = await self.db_pool.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self.db.updates._end_background_update("users_set_deactivated_flag") + await self.db_pool.updates._end_background_update( + "users_set_deactivated_flag" + ) return nb_processed - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): + async def _bg_user_threepids_grandfather(self, progress, batch_size): """We now track which identity servers a user binds their 3PID to, so we need to handle the case of existing bindings where we didn't track this. @@ -897,17 +877,17 @@ def _bg_user_threepids_grandfather_txn(txn): txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self.db.updates._end_background_update("user_threepids_grandfather") + await self.db_pool.updates._end_background_update("user_threepids_grandfather") return 1 class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity @@ -931,23 +911,26 @@ def start_cull(): hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + async def add_access_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + ) -> None: """Adds an access token for the given user. Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the access token + valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. """ next_id = self._access_tokens_id_gen.get_next() - yield self.db.simple_insert( + await self.db_pool.simple_insert( "access_tokens", { "id": next_id, @@ -992,7 +975,7 @@ def register_user( Returns: Deferred """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "register_user", self._register_user, user_id, @@ -1026,7 +1009,7 @@ def _register_user( # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self.db.simple_select_one_txn( + self.db_pool.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1034,7 +1017,7 @@ def _register_user( allow_none=False, ) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1048,7 +1031,7 @@ def _register_user( }, ) else: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "users", values={ @@ -1091,7 +1074,6 @@ def _register_user( ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str @@ -1103,7 +1085,7 @@ def record_user_external_id( external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db.simple_insert( + return self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1121,12 +1103,12 @@ def user_set_password_hash(self, user_id, password_hash): """ def user_set_password_hash_txn(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "user_set_password_hash", user_set_password_hash_txn ) @@ -1143,7 +1125,7 @@ def user_set_consent_version(self, user_id, consent_version): """ def f(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1151,7 +1133,7 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction("user_set_consent_version", f) + return self.db_pool.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1167,7 +1149,7 @@ def user_set_consent_server_notice_sent(self, user_id, consent_version): """ def f(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1175,7 +1157,7 @@ def f(txn): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction("user_set_consent_server_notice_sent", f) + return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1221,11 +1203,11 @@ def f(txn): return tokens_and_devices - return self.db.runInteraction("user_delete_access_tokens", f) + return self.db_pool.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1233,11 +1215,11 @@ def f(txn): txn, self.get_user_by_access_token, (access_token,) ) - return self.db.runInteraction("delete_access_token", f) + return self.db_pool.runInteraction("delete_access_token", f) - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self.db.simple_select_one_onecol( + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1252,7 +1234,7 @@ def add_user_pending_deactivation(self, user_id): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db.simple_insert( + return self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1265,7 +1247,7 @@ def del_user_pending_deactivation(self, user_id): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.db.simple_delete( + return self.db_pool.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1276,7 +1258,7 @@ def get_user_pending_deactivation(self): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1306,7 +1288,7 @@ def validate_threepid_session(self, session_id, client_secret, token, current_ts # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1324,7 +1306,7 @@ def validate_threepid_session_txn(txn): 400, "This client_secret does not match the provided session_id" ) - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1349,7 +1331,7 @@ def validate_threepid_session_txn(txn): ) # Looks good. Validate the session - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1359,7 +1341,7 @@ def validate_threepid_session_txn(txn): return next_link # Return next_link if it exists - return self.db.runInteraction( + return self.db_pool.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1392,7 +1374,7 @@ def upsert_threepid_validation_session( if validated_at: insertion_values["validated_at"] = validated_at - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1430,7 +1412,7 @@ def start_or_continue_validation_session( def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1443,7 +1425,7 @@ def start_or_continue_validation_session_txn(txn): ) # Create a new validation token with this session ID - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1454,7 +1436,7 @@ def start_or_continue_validation_session_txn(txn): }, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1469,22 +1451,23 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts): """ return txn.execute(sql, (ts,)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), ) - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: """Set the `deactivated` property for the provided user to the provided value. Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. """ - yield self.db.runInteraction( + await self.db_pool.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1492,7 +1475,7 @@ def set_user_deactivated_status(self, user_id, deactivated): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1501,9 +1484,9 @@ def set_user_deactivated_status_txn(self, txn, user_id, deactivated): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + txn.call_after(self.is_guest.invalidate, (user_id,)) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): + async def _set_expiration_date_when_missing(self): """ Retrieves the list of registered users that don't have an expiration date, and adds an expiration date for each of them. @@ -1520,14 +1503,14 @@ def select_users_with_no_expiration_date_txn(txn): ) txn.execute(sql, []) - res = self.db.cursor_to_dict(txn) + res = self.db_pool.cursor_to_dict(txn) if res: for user in res: self.set_expiration_date_for_user_txn( txn, user["name"], use_delta=True ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) @@ -1551,7 +1534,7 @@ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): expiration_ts, ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/databases/main/rejections.py similarity index 94% rename from synapse/storage/data_stores/main/rejections.py rename to synapse/storage/databases/main/rejections.py index 27e5a2084a20..cf9ba5120594 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -22,7 +22,7 @@ class RejectionsStore(SQLBaseStore): def get_rejection_reason(self, event_id): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/databases/main/relations.py similarity index 94% rename from synapse/storage/data_stores/main/relations.py rename to synapse/storage/databases/main/relations.py index 7d477f8d0111..a9ceffc20e3c 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,18 +14,20 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -129,7 +131,7 @@ def _get_recent_references_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @@ -223,22 +225,22 @@ def _get_aggregation_groups_for_event_txn(txn): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. """ # We only allow edits for `m.room.message` events that have the same sender @@ -268,15 +270,14 @@ def _get_applicable_edit_txn(txn): if row: return row[0] - edit_id = yield self.db.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key @@ -318,7 +319,7 @@ def _get_if_user_has_annotated_event(txn): return bool(txn.fetchone()) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/databases/main/room.py similarity index 93% rename from synapse/storage/data_stores/main/room.py rename to synapse/storage/databases/main/room.py index ab48052cdc9f..f4008e6221b4 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,8 +27,8 @@ from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.data_stores.main.search import SearchStore -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached @@ -73,7 +73,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -86,7 +86,7 @@ def get_room(self, room_id): Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -118,7 +118,7 @@ def get_room_with_stats_txn(txn, room_id): txn.execute(sql, [room_id]) # Catch error if sql returns empty result to return "None" instead of an error try: - res = self.db.cursor_to_dict(txn)[0] + res = self.db_pool.cursor_to_dict(txn)[0] except IndexError: return None @@ -126,12 +126,12 @@ def get_room_with_stats_txn(txn, room_id): res["public"] = bool(res["public"]) return res - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) def get_public_room_ids(self): - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -188,7 +188,9 @@ def _count_public_rooms_txn(txn): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) + return self.db_pool.runInteraction( + "count_public_rooms", _count_public_rooms_txn + ) async def get_largest_public_rooms( self, @@ -320,21 +322,21 @@ async def get_largest_public_rooms( def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.db.cursor_to_dict(txn) + results = self.db_pool.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = await self.db.runInteraction( + ret_val = await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) return ret_val @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -502,7 +504,7 @@ def _get_rooms_paginate_txn(txn): room_count = txn.fetchone() return rooms, room_count[0] - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_paginate", _get_rooms_paginate_txn, ) @@ -519,7 +521,7 @@ async def get_ratelimit_for_user(self, user_id): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = await self.db.simple_select_one( + row = await self.db_pool.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -561,9 +563,9 @@ def get_retention_policy_for_room_txn(txn): (room_id,), ) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - ret = await self.db.runInteraction( + ret = await self.db_pool.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) @@ -613,7 +615,7 @@ def _get_media_mxcs_in_room_txn(txn): return local_media_mxcs, remote_media_mxcs - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) @@ -630,7 +632,7 @@ def _quarantine_media_in_room_txn(txn): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -714,7 +716,7 @@ def _quarantine_media_by_id_txn(txn): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) @@ -730,7 +732,7 @@ def _quarantine_media_by_user_txn(txn): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db.runInteraction( + return self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -848,7 +850,7 @@ def get_all_new_public_rooms(txn): return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_public_rooms", get_all_new_public_rooms ) @@ -857,21 +859,21 @@ class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, self._remove_tombstoned_rooms_from_directory, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.ADD_ROOMS_ROOM_VERSION_COLUMN, self._background_add_rooms_room_version_column, ) @@ -900,7 +902,7 @@ def _background_insert_retention_txn(txn): (last_room, batch_size), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return True @@ -912,7 +914,7 @@ def _background_insert_retention_txn(txn): ev = db_to_json(row["json"]) retention_policy = ev["content"] - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -925,7 +927,7 @@ def _background_insert_retention_txn(txn): logger.info("Inserted %d rows into room_retention", len(rows)) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -934,12 +936,12 @@ def _background_insert_retention_txn(txn): else: return False - end = await self.db.runInteraction( + end = await self.db_pool.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - await self.db.updates._end_background_update("insert_room_retention") + await self.db_pool.updates._end_background_update("insert_room_retention") return batch_size @@ -983,7 +985,7 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): # mainly for paranoia as much badness would happen if we don't # insert the row and then try and get the room version for the # room. - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="rooms", keyvalues={"room_id": room_id}, @@ -992,19 +994,19 @@ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): ) new_last_room_id = room_id - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} ) return False - end = await self.db.runInteraction( + end = await self.db_pool.runInteraction( "_background_add_rooms_room_version_column", _background_add_rooms_room_version_column_txn, ) if end: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.ADD_ROOMS_ROOM_VERSION_COLUMN ) @@ -1038,12 +1040,12 @@ def _get_rooms(txn): return [row[0] for row in txn] - rooms = await self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_tombstoned_directory_rooms", _get_rooms ) if not rooms: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE ) return 0 @@ -1052,7 +1054,7 @@ def _get_rooms(txn): logger.info("Removing tombstoned room %s from the directory", room_id) await self.set_room_is_public(room_id, False) - await self.db.updates._background_update_progress( + await self.db_pool.updates._background_update_progress( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} ) @@ -1068,7 +1070,7 @@ def set_room_is_public(self, room_id, is_public): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -1079,7 +1081,7 @@ async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): Called when we join a room over federation, and overwrites any room version currently in the table. """ - await self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, @@ -1111,7 +1113,7 @@ async def store_room( try: def store_room_txn(txn, next_id): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "rooms", { @@ -1122,7 +1124,7 @@ def store_room_txn(txn, next_id): }, ) if is_public: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1133,7 +1135,9 @@ def store_room_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - await self.db.runInteraction("store_room_txn", store_room_txn, next_id) + await self.db_pool.runInteraction( + "store_room_txn", store_room_txn, next_id + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1143,7 +1147,7 @@ async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersi When we receive an invite over federation, store the version of the room if we don't already know the room version. """ - await self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="maybe_store_room_on_invite", table="rooms", keyvalues={"room_id": room_id}, @@ -1160,14 +1164,14 @@ async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersi async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self.db.simple_select_list_txn( + entries = self.db_pool.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -1185,7 +1189,7 @@ def set_room_is_public_txn(txn, next_id): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1198,7 +1202,7 @@ def set_room_is_public_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - await self.db.runInteraction( + await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @@ -1224,7 +1228,7 @@ async def set_room_is_public_appservice( def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -1237,7 +1241,7 @@ def set_room_is_public_appservice_txn(txn, next_id): # We've already inserted, nothing to do. return else: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -1247,7 +1251,7 @@ def set_room_is_public_appservice_txn(txn, next_id): }, ) - entries = self.db.simple_select_list_txn( + entries = self.db_pool.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -1265,7 +1269,7 @@ def set_room_is_public_appservice_txn(txn, next_id): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1278,7 +1282,7 @@ def set_room_is_public_appservice_txn(txn, next_id): ) with self._public_room_id_gen.get_next() as next_id: - await self.db.runInteraction( + await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -1295,13 +1299,13 @@ def f(txn): row = txn.fetchone() return row[0] or 0 - return self.db.runInteraction("get_rooms", f) + return self.db_pool.runInteraction("get_rooms", f) def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self.db.simple_insert( + return self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, @@ -1325,14 +1329,14 @@ async def block_room(self, room_id: str, user_id: str) -> None: room_id: Room to block user_id: Who blocked it """ - await self.db.simple_upsert( + await self.db_pool.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - await self.db.runInteraction( + await self.db_pool.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, @@ -1388,7 +1392,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): txn.execute(sql, args) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) rooms_dict = {} for row in rows: @@ -1404,7 +1408,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): txn.execute(sql) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. @@ -1417,7 +1421,7 @@ def get_rooms_for_retention_period_in_range_txn(txn): return rooms_dict - rooms = await self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/databases/main/roommember.py similarity index 80% rename from synapse/storage/data_stores/main/roommember.py rename to synapse/storage/databases/main/roommember.py index a92e401e8864..b2fcfc9bfe83 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Iterable, List, Set +from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( @@ -28,8 +30,8 @@ db_to_json, make_in_list_sql_clause, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -40,9 +42,12 @@ from synapse.types import Collection, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.state import _StateCacheEntry + logger = logging.getLogger(__name__) @@ -51,7 +56,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the @@ -116,7 +121,7 @@ def _transact(txn): txn.execute(query) return list(txn)[0][0] - count = yield self.db.runInteraction("get_known_servers", _transact) + count = yield self.db_pool.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -128,7 +133,7 @@ def _check_safe_current_state_events_membership_updated_txn(self, txn): membership column is up to date """ - pending_update = self.db.simple_select_one_txn( + pending_update = self.db_pool.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -144,18 +149,18 @@ def _check_safe_current_state_events_membership_updated_txn(self, txn): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.db.runInteraction, + self.db_pool.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.db.runInteraction( + def get_users_in_room(self, room_id: str): + return self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id): + def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -178,11 +183,11 @@ def get_users_in_room_txn(self, txn, room_id): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id): + def get_room_summary(self, room_id: str): """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: - room_id (str): The room ID to query + room_id: The room ID to query Returns: Deferred[dict[str, MemberSummary]: dict of membership states, pointing to a MemberSummary named tuple. @@ -259,80 +264,61 @@ def _get_room_summary_txn(txn): return res - return self.db.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} + return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) @cached() - def get_invited_rooms_for_local_user(self, user_id): - """ Get all the rooms the *local* user is invited to + def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + """Get all the rooms the *local* user is invited to. Args: - user_id (str): The user ID. + user_id: The user ID. + Returns: - A deferred list of RoomsForUser. + A awaitable list of RoomsForUser. """ return self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) - @defer.inlineCallbacks - def get_invite_for_local_user_in_room(self, user_id, room_id): - """Gets the invite for the given *local* user and room + async def get_invite_for_local_user_in_room( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUser]: + """Gets the invite for the given *local* user and room. Args: - user_id (str) - room_id (str) + user_id: The user ID to find the invite of. + room_id: The room to user was invited to. Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. + Either a RoomsForUser or None if no invite was found. """ - invites = yield self.get_invited_rooms_for_local_user(user_id) + invites = await self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None - @defer.inlineCallbacks - def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this *local* user where the membership for this user + async def get_rooms_for_local_user_where_membership_is( + self, user_id: str, membership_list: List[str] + ) -> Optional[List[RoomsForUser]]: + """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. + user_id: The user ID. + membership_list: A list of synapse.api.constants.Membership + values which the user must be in. Returns: - Deferred[list[RoomsForUser]] + The RoomsForUser that the user matches the membership types. """ if not membership_list: - return defer.succeed(None) + return None - rooms = yield self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", self._get_rooms_for_local_user_where_membership_is_txn, user_id, @@ -340,12 +326,12 @@ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list) ) # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): + self, txn, user_id: str, membership_list: List[str] + ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): raise Exception( @@ -369,32 +355,32 @@ def _get_rooms_for_local_user_where_membership_is_txn( ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): + def get_rooms_for_user_with_stream_ordering(self, user_id: str): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. Args: - user_id (str) + user_id Returns: Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns the rooms the user is in currently, along with the stream ordering of the most recent join for that user and room. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -453,42 +439,44 @@ def _get_users_server_still_shares_room_with_txn(txn): return {row[0] for row in txn} - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_server_still_shares_room_with", _get_users_server_still_shares_room_with_txn, ) - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): + async def get_rooms_for_user(self, user_id: str, on_invalidate=None): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( + rooms = await self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) return frozenset(r.room_id for r in rooms) - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): + @cached(max_entries=500000, cache_context=True, iterable=True) + async def get_users_who_share_room_with_user( + self, user_id: str, cache_context: _CacheContext + ) -> Set[str]: """Returns the set of users who share a room with `user_id` """ - room_ids = yield self.get_rooms_for_user( + room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: - user_ids = yield self.get_users_in_room( + user_ids = await self.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) return user_who_share_room - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): + async def get_joined_users_from_context( + self, event: EventBase, context: EventContext + ): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -497,14 +485,12 @@ def get_joined_users_from_context(self, event, context): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._get_joined_users_from_context( + current_state_ids = await context.get_current_state_ids() + return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - return result - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -514,16 +500,12 @@ def get_joined_users_from_state(self, room_id, state_entry): state_group = object() with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + return await self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( + @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) + async def _get_joined_users_from_context( self, room_id, state_group, @@ -535,7 +517,6 @@ def _get_joined_users_from_context( # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None users_in_room = {} @@ -588,7 +569,7 @@ def _get_joined_users_from_context( missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) users_in_room.update((row for row in event_to_memberships.values() if row)) @@ -612,19 +593,19 @@ def _get_joined_profile_from_event_id(self, event_id): list_name="event_ids", inlineCallbacks=True, ) - def _get_joined_profiles_from_event_ids(self, event_ids): + def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. Args: - event_ids (Iterable[str]): The member event IDs to lookup + event_ids: The member event IDs to lookup Returns: Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -644,8 +625,8 @@ def _get_joined_profiles_from_event_ids(self, event_ids): for row in rows } - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): + @cached(max_entries=10000) + async def is_host_joined(self, room_id: str, host: str) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") @@ -664,47 +645,9 @@ def is_host_joined(self, room_id, host): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) + rows = await self.db_pool.execute( + "is_host_joined", None, sql, room_id, like_clause + ) if not rows: return False @@ -716,8 +659,7 @@ def was_host_joined(self, room_id, host): return True - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): + async def get_joined_hosts(self, room_id: str, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -727,32 +669,28 @@ def get_joined_hosts(self, room_id, state_entry): state_group = object() with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + return await self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry ) - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, room_id, state_group, current_state_ids, state_entry + ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = yield self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts + cache = await self._get_joined_hosts_cache(room_id) + return await cache.get_destinations(state_entry) @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): + def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache(self, room_id) - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): + @cached(num_args=2) + async def did_forget(self, user_id: str, room_id: str) -> bool: """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" @@ -774,15 +712,15 @@ def f(txn): rows = txn.fetchall() return rows[0][0] - count = yield self.db.runInteraction("did_forget_membership", f) + count = await self.db_pool.runInteraction("did_forget_membership", f) return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id): + def get_forgotten_rooms_for_user(self, user_id: str): """Gets all rooms the user has forgotten. Args: - user_id (str) + user_id Returns: Deferred[set[str]] @@ -811,22 +749,21 @@ def _get_forgotten_rooms_for_user_txn(txn): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. Args: - user_id (str) + user_id: The user ID to get the rooms of. Returns: - Deferred[set[str]]: Set of room IDs. + Set of room IDs. """ - room_ids = yield self.db.simple_select_onecol( + room_ids = await self.db_pool.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -841,7 +778,7 @@ def get_membership_from_event_ids( """Get user_id and membership of a set of event IDs. """ - return self.db.simple_select_many_batch( + return self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -877,23 +814,23 @@ def _is_local_host_in_room_ignoring_users_txn(txn): return bool(txn.fetchone()) - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "is_local_host_in_room_ignoring_users", _is_local_host_in_room_ignoring_users_txn, ) class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -901,8 +838,7 @@ def __init__(self, database: Database, db_conn, hs): where_clause="forgotten = 1", ) - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( "target_min_stream_id_inclusive", self._min_stream_order_on_start ) @@ -926,7 +862,7 @@ def add_membership_profile_txn(txn): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return 0 @@ -961,25 +897,24 @@ def add_membership_profile_txn(txn): "max_stream_id_exclusive": min_stream_id, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) return len(rows) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( _MEMBERSHIP_PROFILE_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership(self, progress, batch_size): """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. @@ -1013,7 +948,7 @@ def _background_current_state_membership_txn(txn, last_processed_room): last_processed_room = next_room - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -1025,14 +960,14 @@ def _background_current_state_membership_txn(txn, last_processed_room): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.db.runInteraction( + row_count, finished = await self.db_pool.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME ) @@ -1040,10 +975,10 @@ def _background_current_state_membership_txn(txn, last_processed_room): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id, room_id): + def forget(self, user_id: str, room_id: str): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -1064,7 +999,7 @@ def f(txn): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db.runInteraction("forget_membership", f) + return self.db_pool.runInteraction("forget_membership", f) class _JoinedHostsCache(object): @@ -1084,17 +1019,19 @@ def __init__(self, store, room_id): self._len = 0 - @defer.inlineCallbacks - def get_destinations(self, state_entry): + async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: """Get set of destinations for a state entry Args: - state_entry(synapse.state._StateCacheEntry) + state_entry + + Returns: + The destinations as a set. """ if state_entry.state_group == self.state_group: return frozenset(self.hosts_to_joined_users) - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: @@ -1106,7 +1043,7 @@ def get_destinations(self, state_entry): user_id = state_key known_joins = self.hosts_to_joined_users.setdefault(host, set()) - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.membership == Membership.JOIN: known_joins.add(user_id) else: @@ -1115,7 +1052,7 @@ def get_destinations(self, state_entry): if not known_joins: self.hosts_to_joined_users.pop(host, None) else: - joined_users = yield self.store.get_joined_users_from_state( + joined_users = await self.store.get_joined_users_from_state( self.room_id, state_entry ) diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/databases/main/schema/delta/12/v12.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/12/v12.sql rename to synapse/storage/databases/main/schema/delta/12/v12.sql diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/databases/main/schema/delta/13/v13.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/13/v13.sql rename to synapse/storage/databases/main/schema/delta/13/v13.sql diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/databases/main/schema/delta/14/v14.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/14/v14.sql rename to synapse/storage/databases/main/schema/delta/14/v14.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql rename to synapse/storage/databases/main/schema/delta/15/appservice_txns.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql rename to synapse/storage/databases/main/schema/delta/15/presence_indices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/databases/main/schema/delta/15/v15.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/15/v15.sql rename to synapse/storage/databases/main/schema/delta/15/v15.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql rename to synapse/storage/databases/main/schema/delta/16/events_order_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql rename to synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql rename to synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql rename to synapse/storage/databases/main/schema/delta/16/room_alias_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql rename to synapse/storage/databases/main/schema/delta/16/unique_constraints.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/databases/main/schema/delta/16/users.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/16/users.sql rename to synapse/storage/databases/main/schema/delta/16/users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql rename to synapse/storage/databases/main/schema/delta/17/drop_indexes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/databases/main/schema/delta/17/server_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/17/server_keys.sql rename to synapse/storage/databases/main/schema/delta/17/server_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql rename to synapse/storage/databases/main/schema/delta/17/user_threepids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql rename to synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/databases/main/schema/delta/19/event_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/19/event_index.sql rename to synapse/storage/databases/main/schema/delta/19/event_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/databases/main/schema/delta/20/dummy.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/20/dummy.sql rename to synapse/storage/databases/main/schema/delta/20/dummy.sql diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/20/pushers.py rename to synapse/storage/databases/main/schema/delta/20/pushers.py diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql rename to synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/databases/main/schema/delta/21/receipts.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/21/receipts.sql rename to synapse/storage/databases/main/schema/delta/21/receipts.sql diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql rename to synapse/storage/databases/main/schema/delta/22/receipts_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql rename to synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql rename to synapse/storage/databases/main/schema/delta/24/stats_reporting.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/25/fts.py rename to synapse/storage/databases/main/schema/delta/25/fts.py diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/databases/main/schema/delta/25/guest_access.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/25/guest_access.sql rename to synapse/storage/databases/main/schema/delta/25/guest_access.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql rename to synapse/storage/databases/main/schema/delta/25/history_visibility.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/databases/main/schema/delta/25/tags.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/25/tags.sql rename to synapse/storage/databases/main/schema/delta/25/tags.sql diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/databases/main/schema/delta/26/account_data.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/26/account_data.sql rename to synapse/storage/databases/main/schema/delta/26/account_data.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/databases/main/schema/delta/27/account_data.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/27/account_data.sql rename to synapse/storage/databases/main/schema/delta/27/account_data.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql rename to synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/27/ts.py rename to synapse/storage/databases/main/schema/delta/27/ts.py diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql rename to synapse/storage/databases/main/schema/delta/28/event_push_actions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql rename to synapse/storage/databases/main/schema/delta/28/events_room_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql rename to synapse/storage/databases/main/schema/delta/28/public_roms_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql rename to synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql rename to synapse/storage/databases/main/schema/delta/28/upgrade_times.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql rename to synapse/storage/databases/main/schema/delta/28/users_is_guest.sql diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/databases/main/schema/delta/29/push_actions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/29/push_actions.sql rename to synapse/storage/databases/main/schema/delta/29/push_actions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql rename to synapse/storage/databases/main/schema/delta/30/alias_creator.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/as_users.py rename to synapse/storage/databases/main/schema/delta/30/as_users.py diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql rename to synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql rename to synapse/storage/databases/main/schema/delta/30/presence_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql rename to synapse/storage/databases/main/schema/delta/30/public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql rename to synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql rename to synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/databases/main/schema/delta/31/invites.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/31/invites.sql rename to synapse/storage/databases/main/schema/delta/31/invites.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql rename to synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/31/pushers.py rename to synapse/storage/databases/main/schema/delta/31/pushers.py diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql rename to synapse/storage/databases/main/schema/delta/31/pushers_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/31/search_update.py rename to synapse/storage/databases/main/schema/delta/31/search_update.py diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/databases/main/schema/delta/32/events.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/32/events.sql rename to synapse/storage/databases/main/schema/delta/32/events.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/databases/main/schema/delta/32/openid.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/32/openid.sql rename to synapse/storage/databases/main/schema/delta/32/openid.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql rename to synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql rename to synapse/storage/databases/main/schema/delta/32/remove_indices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/databases/main/schema/delta/32/reports.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/32/reports.sql rename to synapse/storage/databases/main/schema/delta/32/reports.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql rename to synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/databases/main/schema/delta/33/devices.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/devices.sql rename to synapse/storage/databases/main/schema/delta/33/devices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql rename to synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql rename to synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/event_fields.py rename to synapse/storage/databases/main/schema/delta/33/event_fields.py diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py rename to synapse/storage/databases/main/schema/delta/33/remote_media_ts.py diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql rename to synapse/storage/databases/main/schema/delta/33/user_ips_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql rename to synapse/storage/databases/main/schema/delta/34/appservice_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/databases/main/schema/delta/34/cache_stream.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/34/cache_stream.py rename to synapse/storage/databases/main/schema/delta/34/cache_stream.py diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql rename to synapse/storage/databases/main/schema/delta/34/device_inbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql rename to synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py rename to synapse/storage/databases/main/schema/delta/34/received_txn_purge.py diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/databases/main/schema/delta/35/contains_url.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/contains_url.sql rename to synapse/storage/databases/main/schema/delta/35/contains_url.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql rename to synapse/storage/databases/main/schema/delta/35/device_outbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql rename to synapse/storage/databases/main/schema/delta/35/device_stream_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql rename to synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql rename to synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql rename to synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql rename to synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py rename to synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql rename to synapse/storage/databases/main/schema/delta/37/user_threepids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql rename to synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql rename to synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql rename to synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql rename to synapse/storage/databases/main/schema/delta/39/event_push_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql rename to synapse/storage/databases/main/schema/delta/39/federation_out_position.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql rename to synapse/storage/databases/main/schema/delta/39/membership_profile.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql rename to synapse/storage/databases/main/schema/delta/40/current_state_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql rename to synapse/storage/databases/main/schema/delta/40/device_inbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql rename to synapse/storage/databases/main/schema/delta/40/device_list_streams.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql rename to synapse/storage/databases/main/schema/delta/40/event_push_summary.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/databases/main/schema/delta/40/pushers.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/40/pushers.sql rename to synapse/storage/databases/main/schema/delta/40/pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql rename to synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql rename to synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql rename to synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql rename to synapse/storage/databases/main/schema/delta/41/ratelimit.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql rename to synapse/storage/databases/main/schema/delta/42/current_state_delta.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql rename to synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql rename to synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/databases/main/schema/delta/42/user_dir.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/42/user_dir.py rename to synapse/storage/databases/main/schema/delta/42/user_dir.py diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql rename to synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql rename to synapse/storage/databases/main/schema/delta/43/quarantine_media.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/databases/main/schema/delta/43/url_cache.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/43/url_cache.sql rename to synapse/storage/databases/main/schema/delta/43/url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/databases/main/schema/delta/43/user_share.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/43/user_share.sql rename to synapse/storage/databases/main/schema/delta/43/user_share.sql diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql rename to synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/databases/main/schema/delta/45/group_server.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/45/group_server.sql rename to synapse/storage/databases/main/schema/delta/45/group_server.sql diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql rename to synapse/storage/databases/main/schema/delta/45/profile_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql rename to synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql rename to synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/databases/main/schema/delta/46/group_server.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/group_server.sql rename to synapse/storage/databases/main/schema/delta/46/group_server.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql rename to synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql rename to synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql rename to synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql rename to synapse/storage/databases/main/schema/delta/47/last_access_media.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql rename to synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql rename to synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql rename to synapse/storage/databases/main/schema/delta/48/add_user_consent.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql rename to synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql rename to synapse/storage/databases/main/schema/delta/48/deactivated_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py rename to synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql rename to synapse/storage/databases/main/schema/delta/48/groups_joinable.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql rename to synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql rename to synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql rename to synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql rename to synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql rename to synapse/storage/databases/main/schema/delta/50/erasure_store.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py rename to synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql rename to synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql rename to synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql rename to synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql rename to synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql rename to synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql rename to synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql rename to synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql rename to synapse/storage/databases/main/schema/delta/53/event_format_version.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql rename to synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql rename to synapse/storage/databases/main/schema/delta/53/user_ips_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/databases/main/schema/delta/53/user_share.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/user_share.sql rename to synapse/storage/databases/main/schema/delta/53/user_share.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql rename to synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql rename to synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql rename to synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql rename to synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql rename to synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql rename to synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql rename to synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/databases/main/schema/delta/54/relations.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/relations.sql rename to synapse/storage/databases/main/schema/delta/54/relations.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/databases/main/schema/delta/54/stats.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/stats.sql rename to synapse/storage/databases/main/schema/delta/54/stats.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/databases/main/schema/delta/54/stats2.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/54/stats2.sql rename to synapse/storage/databases/main/schema/delta/54/stats2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql rename to synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql rename to synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql rename to synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql rename to synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql rename to synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql rename to synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql rename to synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql rename to synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres rename to synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql rename to synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql rename to synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql rename to synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql rename to synapse/storage/databases/main/schema/delta/56/event_expiry.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/event_labels.sql rename to synapse/storage/databases/main/schema/delta/56/event_labels.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql rename to synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql rename to synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql rename to synapse/storage/databases/main/schema/delta/56/hidden_devices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite rename to synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql rename to synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql rename to synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql rename to synapse/storage/databases/main/schema/delta/56/redaction_censor.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql rename to synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres rename to synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql rename to synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql rename to synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql rename to synapse/storage/databases/main/schema/delta/56/room_key_etag.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql rename to synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/databases/main/schema/delta/56/room_retention.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/room_retention.sql rename to synapse/storage/databases/main/schema/delta/56/room_retention.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql rename to synapse/storage/databases/main/schema/delta/56/signing_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql rename to synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql rename to synapse/storage/databases/main/schema/delta/56/stats_separated.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py rename to synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql rename to synapse/storage/databases/main/schema/delta/56/user_external_ids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql rename to synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql rename to synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql rename to synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py rename to synapse/storage/databases/main/schema/delta/57/local_current_membership.py diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql rename to synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql rename to synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres rename to synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite rename to synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres rename to synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite rename to synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql rename to synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql rename to synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres rename to synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py rename to synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres rename to synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite rename to synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql rename to synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql rename to synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py similarity index 94% rename from synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py rename to synapse/storage/databases/main/schema/delta/58/11user_id_seq.py index 2011f6bcebc2..4310ec12ce1a 100644 --- a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py +++ b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py @@ -16,7 +16,7 @@ Adds a postgres SEQUENCE for generating guest user IDs. """ -from synapse.storage.data_stores.main.registration import ( +from synapse.storage.databases.main.registration import ( find_max_generated_user_id_localpart, ) from synapse.storage.engines import PostgresEngine diff --git a/synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/58/12room_stats.sql rename to synapse/storage/databases/main/schema/delta/58/12room_stats.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql rename to synapse/storage/databases/main/schema/full_schemas/16/application_services.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql rename to synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql rename to synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/databases/main/schema/full_schemas/16/im.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/im.sql rename to synapse/storage/databases/main/schema/full_schemas/16/im.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql rename to synapse/storage/databases/main/schema/full_schemas/16/keys.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql rename to synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql rename to synapse/storage/databases/main/schema/full_schemas/16/presence.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql rename to synapse/storage/databases/main/schema/full_schemas/16/profiles.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/databases/main/schema/full_schemas/16/push.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/push.sql rename to synapse/storage/databases/main/schema/full_schemas/16/push.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql rename to synapse/storage/databases/main/schema/full_schemas/16/redactions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql rename to synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/databases/main/schema/full_schemas/16/state.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/state.sql rename to synapse/storage/databases/main/schema/full_schemas/16/state.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql rename to synapse/storage/databases/main/schema/full_schemas/16/transactions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/databases/main/schema/full_schemas/16/users.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/16/users.sql rename to synapse/storage/databases/main/schema/full_schemas/16/users.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres rename to synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite rename to synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql rename to synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/databases/main/schema/full_schemas/README.md similarity index 100% rename from synapse/storage/data_stores/main/schema/full_schemas/README.md rename to synapse/storage/databases/main/schema/full_schemas/README.md diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/databases/main/search.py similarity index 88% rename from synapse/storage/data_stores/main/search.py rename to synapse/storage/databases/main/search.py index d52228297c28..7f8d1880e57e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,13 +16,12 @@ import logging import re from collections import namedtuple - -from twisted.internet import defer +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -88,16 +87,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -106,16 +105,15 @@ def __init__(self, database: Database, db_conn, hs): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -140,7 +138,7 @@ def reindex_search_txn(txn): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return 0 @@ -200,23 +198,24 @@ def reindex_search_txn(txn): "rows_inserted": rows_inserted + len(event_search_rows), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) return len(event_search_rows) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_UPDATE_NAME + ) return result - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search(self, progress, batch_size): """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ @@ -253,15 +252,14 @@ def create_index(conn): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.db.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME ) return 1 - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -286,14 +284,14 @@ def create_index(conn): ) conn.set_session(autocommit=False) - yield self.db.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.db.runInteraction( + await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -323,18 +321,18 @@ def reindex_search_txn(txn): "have_added_indexes": True, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) return len(rows), True - num_rows, finished = yield self.db.runInteraction( + num_rows, finished = await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_ORDER_UPDATE_NAME ) @@ -342,11 +340,10 @@ def reindex_search_txn(txn): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): + async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. Args: @@ -423,15 +420,15 @@ def search_msgs(self, room_ids, search_term, keys): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.db.execute( - "search_msgs", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_msgs", self.db_pool.cursor_to_dict, sql, *args ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -440,12 +437,12 @@ def search_msgs(self, room_ids, search_term, keys): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args + count_results = await self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -460,19 +457,25 @@ def search_msgs(self, room_ids, search_term, keys): "count": count, } - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + async def search_rooms( + self, + room_ids: List[str], + search_term: str, + keys: List[str], + limit, + pagination_token: Optional[str] = None, + ) -> List[dict]: """Performs a full text search over events with given keys. Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned + room_ids: The room_ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", + "content.name", "content.topic" + pagination_token: A pagination token previously returned Returns: - list of dicts + Each match as a dictionary. """ clauses = [] @@ -575,15 +578,15 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None args.append(limit) - results = yield self.db.execute( - "search_rooms", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_rooms", self.db_pool.cursor_to_dict, sql, *args ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -592,12 +595,12 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args + count_results = await self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -682,7 +685,7 @@ def f(txn): return highlight_words - return self.db.runInteraction("_find_highlights", f) + return self.db_pool.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/databases/main/signatures.py similarity index 89% rename from synapse/storage/data_stores/main/signatures.py rename to synapse/storage/databases/main/signatures.py index 36244d9f5da7..be191dd8708c 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -15,8 +15,6 @@ from unpaddedbase64 import encode_base64 -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -38,11 +36,10 @@ def f(txn): for event_id in event_ids } - return self.db.runInteraction("get_event_reference_hashes", f) + return self.db_pool.runInteraction("get_event_reference_hashes", f) - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) + async def add_event_hashes(self, event_ids): + hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/databases/main/state.py similarity index 93% rename from synapse/storage/data_stores/main/state.py rename to synapse/storage/databases/main/state.py index a36069940829..96e0378e5068 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -23,9 +23,9 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -54,7 +54,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -93,7 +93,7 @@ async def get_room_version_id(self, room_id: str) -> str: # We really should have an entry in the rooms table for every room we # care about, but let's be a bit paranoid (at least while the background # update is happening) to avoid breaking existing rooms. - version = await self.db.simple_select_one_onecol( + version = await self.db_pool.simple_select_one_onecol( table="rooms", keyvalues={"room_id": room_id}, retcol="room_version", @@ -184,7 +184,7 @@ def _get_current_state_ids_txn(txn): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) @@ -231,7 +231,7 @@ def _get_filtered_current_state_ids_txn(txn): return results - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -261,7 +261,7 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -278,7 +278,7 @@ def _get_state_group_for_event(self, event_id): def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -301,7 +301,7 @@ async def get_referenced_state_groups( The subset of state groups that are referenced. """ - rows = await self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -319,25 +319,25 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", columns=["state_group"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, ) @@ -429,7 +429,7 @@ def _background_remove_left_rooms_txn(txn): # potentially stale, since there may have been a period where the # server didn't share a room with the remote user and therefore may # have missed any device updates. - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="current_state_events", column="room_id", @@ -441,7 +441,7 @@ def _background_remove_left_rooms_txn(txn): potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="current_state_events", column="room_id", @@ -449,7 +449,7 @@ def _background_remove_left_rooms_txn(txn): keyvalues={}, ) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="event_forward_extremities", column="room_id", @@ -457,7 +457,7 @@ def _background_remove_left_rooms_txn(txn): keyvalues={}, ) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.DELETE_CURRENT_STATE_UPDATE_NAME, {"last_room_id": room_ids[-1]}, @@ -465,12 +465,12 @@ def _background_remove_left_rooms_txn(txn): return False, potentially_left_users - finished, potentially_left_users = await self.db.runInteraction( + finished, potentially_left_users = await self.db_pool.runInteraction( "_background_remove_left_rooms", _background_remove_left_rooms_txn ) if finished: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.DELETE_CURRENT_STATE_UPDATE_NAME ) @@ -505,5 +505,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py similarity index 95% rename from synapse/storage/data_stores/main/state_deltas.py rename to synapse/storage/databases/main/state_deltas.py index 725e12507f7c..0d963c98ffa9 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -100,14 +100,14 @@ def get_current_state_deltas_txn(txn): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.db.cursor_to_dict(txn) + return clipped_stream_id, self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self.db.simple_select_one_onecol_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -115,7 +115,7 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn): ) def get_max_stream_id_in_current_state_deltas(self): - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/databases/main/stats.py similarity index 92% rename from synapse/storage/data_stores/main/stats.py rename to synapse/storage/databases/main/stats.py index 40db8f594eaa..802c9019b9f4 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -21,8 +21,8 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -59,7 +59,7 @@ class StatsStore(StateDeltasStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -69,20 +69,20 @@ def __init__(self, database: Database, db_conn, hs): self.stats_delta_processing_lock = DeferredLock() - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.db.updates.register_noop_background_update("populate_stats_cleanup") - self.db.updates.register_noop_background_update("populate_stats_prepare") + self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") + self.db_pool.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -105,7 +105,9 @@ async def _populate_stats_process_users(self, progress, batch_size): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - await self.db.updates._end_background_update("populate_stats_process_users") + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) return 1 last_user_id = progress.get("last_user_id", "") @@ -120,22 +122,24 @@ def _get_next_batch(txn): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = await self.db.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - await self.db.updates._end_background_update("populate_stats_process_users") + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) return 1 for user_id in users_to_work_on: await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - await self.db.runInteraction( + await self.db_pool.runInteraction( "populate_stats_process_users", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -153,7 +157,9 @@ async def _populate_stats_process_rooms(self, progress, batch_size): Further context: https://github.com/matrix-org/synapse/pull/7977 """ - await self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db_pool.updates._end_background_update( + "populate_stats_process_rooms" + ) return 1 async def _populate_stats_process_rooms_2(self, progress, batch_size): @@ -164,7 +170,7 @@ async def _populate_stats_process_rooms_2(self, progress, batch_size): reasoning. """ if not self.stats_enabled: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_stats_process_rooms_2" ) return 1 @@ -181,13 +187,13 @@ def _get_next_batch(txn): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = await self.db.runInteraction( + rooms_to_work_on = await self.db_pool.runInteraction( "populate_stats_rooms_2_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_stats_process_rooms_2" ) return 1 @@ -196,9 +202,9 @@ def _get_next_batch(txn): await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - await self.db.runInteraction( + await self.db_pool.runInteraction( "_populate_stats_process_rooms_2", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_stats_process_rooms_2", progress, ) @@ -209,7 +215,7 @@ def get_stats_positions(self): """ Returns the stats processor positions. """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -238,7 +244,7 @@ def update_room_state(self, room_id, fields): if field and "\0" in field: fields[col] = None - return self.db.simple_upsert( + return self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -259,7 +265,7 @@ def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -280,7 +286,7 @@ def _get_statistics_for_subject_txn( ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self.db.simple_select_list_paginate_txn( + slice_list = self.db_pool.simple_select_list_paginate_txn( txn, table + "_historical", "end_ts", @@ -306,7 +312,7 @@ def get_earliest_token_for_stats(self, stats_type, id): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -342,14 +348,14 @@ def _bulk_update_stats_delta_txn(txn): complete_with_stream_id=stream_id, ) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -380,7 +386,7 @@ def update_stats_delta( Does not work with per-slice fields. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -515,17 +521,17 @@ def _upsert_with_additive_relatives_txn( else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.db.simple_select_one_txn( + current_row = self.db_pool.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.db.simple_insert_txn(txn, table, merged_dict) + self.db_pool.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self.db.simple_update_one_txn(txn, table, keyvalues, current_row) + self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -612,11 +618,11 @@ def _upsert_copy_from_table_with_additive_relatives_txn( txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self.db.simple_select_one_txn( + src_row = self.db_pool.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db.simple_select_one_txn( + dest_current_row = self.db_pool.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -632,11 +638,13 @@ def _upsert_copy_from_table_with_additive_relatives_txn( **src_row, **additive_relatives, } - self.db.simple_insert_txn(txn, into_table, merged_dict) + self.db_pool.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db_pool.simple_update_txn( + txn, into_table, all_dest_keyvalues, src_row + ) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -650,7 +658,7 @@ def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): changes. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -733,7 +741,7 @@ async def _calculate_and_set_initial_state_for_room( def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -789,7 +797,7 @@ def _fetch_current_state_stats(txn): current_state_events_count, users_in_room, pos, - ) = await self.db.runInteraction( + ) = await self.db_pool.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -863,7 +871,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = await self.db.runInteraction( + joined_rooms, pos = await self.db_pool.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/databases/main/stream.py similarity index 96% rename from synapse/storage/data_stores/main/stream.py rename to synapse/storage/databases/main/stream.py index f1334a6efce7..aaf225894e23 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -45,8 +45,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database, make_in_list_sql_clause +from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -251,7 +251,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StreamWorkerStore, self).__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -265,7 +265,7 @@ def __init__(self, database: Database, db_conn, hs): self._need_to_reset_federation_stream_positions = self._send_federation events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self.db.get_cache_dict( + event_cache_prefill, min_event_val = self.db_pool.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -410,7 +410,7 @@ def f(txn): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -460,7 +460,7 @@ def f(txn): return rows - rows = yield self.db.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -519,7 +519,7 @@ def get_recent_event_ids_for_room(self, room_id, limit, end_token): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db.runInteraction( + rows, token = yield self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -556,7 +556,7 @@ def _f(txn): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.db.runInteraction("get_room_event_before_stream_ordering", _f) + return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. @@ -569,7 +569,7 @@ async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: if room_id is None: return "s%d" % (token,) else: - topo = await self.db.runInteraction( + topo = await self.db_pool.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -583,7 +583,7 @@ def get_stream_token_for_event(self, event_id): Returns: A deferred "s%d" stream token. """ - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -596,7 +596,7 @@ def get_topological_token_for_event(self, event_id): Returns: A deferred "t%d-%d" topological token. """ - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -620,7 +620,7 @@ def get_max_topological_token(self, room_id, stream_key): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.db.execute( + return self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -674,7 +674,7 @@ def get_events_around( dict """ - results = yield self.db.runInteraction( + results = yield self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -716,7 +716,7 @@ def _get_events_around_txn( dict """ - results = self.db.simple_select_one_txn( + results = self.db_pool.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -795,7 +795,7 @@ def get_all_new_events_stream_txn(txn): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db.runInteraction( + upper_bound, event_ids = yield self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -805,12 +805,12 @@ def get_all_new_events_stream_txn(txn): async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: - await self.db.runInteraction( + await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ, "instance_name": self._instance_name}, @@ -819,12 +819,12 @@ async def get_federation_out_pos(self, typ: str) -> int: async def update_federation_out_pos(self, typ, stream_id): if self._need_to_reset_federation_stream_positions: - await self.db.runInteraction( + await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db.simple_update_one( + return await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, @@ -854,7 +854,7 @@ def _reset_federation_positions_txn(self, txn): elif self._instance_name not in configured_instances: return - instances_in_table = self.db.simple_select_onecol_txn( + instances_in_table = self.db_pool.simple_select_onecol_txn( txn, table="federation_stream_position", keyvalues={}, @@ -885,7 +885,7 @@ def _reset_federation_positions_txn(self, txn): txn.execute(sql % (clause,), args) for typ, stream_id in min_positions.items(): - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="federation_stream_position", keyvalues={"type": typ, "instance_name": self._instance_name}, @@ -1036,7 +1036,7 @@ def paginate_room_events( if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db.runInteraction( + rows, token = yield self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/databases/main/tags.py similarity index 74% rename from synapse/storage/data_stores/main/tags.py rename to synapse/storage/databases/main/tags.py index bd7227773aee..e4e0a0c43379 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,14 +15,13 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import Dict, List, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.storage._base import db_to_json -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -30,30 +29,26 @@ class TagsWorkerStore(AccountDataWorkerStore): @cached() - def get_tags_for_user(self, user_id): + async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for a user. Args: - user_id(str): The user to get the tags for. + user_id: The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. + A mapping from room_id strings to dicts mapping from tag strings to + tag content. """ - deferred = self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) - return tags_by_room - - return deferred + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -92,7 +87,7 @@ def get_all_updated_tags_txn(txn): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = await self.db.runInteraction( + tag_ids = await self.db_pool.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -112,7 +107,7 @@ def get_tag_content(txn, tag_ids): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = await self.db.runInteraction( + tags = await self.db_pool.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -127,17 +122,19 @@ def get_tag_content(txn, tag_ids): return results, upto_token, limited - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): + async def get_updated_tags( + self, user_id: str, stream_id: int + ) -> Dict[str, List[str]]: """Get all the tags for the rooms where the tags have changed since the given version Args: user_id(str): The user to get the tags for. stream_id(int): The earliest update to get for the user. + Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. + A mapping from room_id strings to lists of tag strings for all the + rooms that changed since the stream_id token. """ def get_updated_tags_txn(txn): @@ -155,52 +152,58 @@ def get_updated_tags_txn(txn): if not changed: return {} - room_ids = yield self.db.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_updated_tags", get_updated_tags_txn ) results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) + tags_by_room = await self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) return results - def get_tags_for_room(self, user_id, room_id): + async def get_tags_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the tags for the given room + Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for + user_id: The user to get tags for + room_id: The room to get tags for + Returns: - A deferred list of string tags. + A mapping of tags to tag content. """ - return self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) + return {row["tag"]: db_to_json(row["content"]) for row in rows} class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: """Add a tag to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the tag has been added. + The next account data ID. """ content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -209,18 +212,17 @@ def add_tag_txn(txn, next_id): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("add_tag", add_tag_txn, next_id) + await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: """Remove a tag from a room for a user. + Returns: - A deferred that completes once the tag has been removed + The next account data ID. """ def remove_tag_txn(txn, next_id): @@ -232,21 +234,22 @@ def remove_tag_txn(txn, next_id): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) + await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_revision_txn(self, txn, user_id, room_id, next_id): + def _update_revision_txn( + self, txn, user_id: str, room_id: str, next_id: int + ) -> None: """Update the latest revision of the tags for the given user and room. Args: txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. + user_id: The ID of the user. + room_id: The ID of the room. + next_id: The the revision to advance to. """ txn.call_after( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/databases/main/transactions.py similarity index 92% rename from synapse/storage/data_stores/main/transactions.py rename to synapse/storage/databases/main/transactions.py index a9bf457939f1..52668dbdf9cf 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -18,11 +18,9 @@ from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -46,7 +44,7 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -71,7 +69,7 @@ def get_received_txn_response(self, transaction_id, origin): this transaction or a 2-tuple of (int, dict) """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -79,7 +77,7 @@ def get_received_txn_response(self, transaction_id, origin): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -113,7 +111,7 @@ def set_received_txn_response(self, transaction_id, origin, code, response_dict) response_json (str) """ - return self.db.simple_insert( + return self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -126,8 +124,7 @@ def set_received_txn_response(self, transaction_id, origin, code, response_dict) desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +139,7 @@ def get_destination_retry_timings(self, destination): if result is not SENTINEL: return result - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -154,7 +151,7 @@ def get_destination_retry_timings(self, destination): return result def _get_destination_retry_timings(self, txn, destination): - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -181,7 +178,7 @@ def set_destination_retry_timings( """ self._destination_retry_cache.pop(destination, None) - return self.db.runInteraction( + return self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -221,7 +218,7 @@ def _set_destination_retry_timings( # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self.db.simple_select_one_txn( + prev_row = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -230,7 +227,7 @@ def _set_destination_retry_timings( ) if not prev_row: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="destinations", values={ @@ -241,7 +238,7 @@ def _set_destination_retry_timings( }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -264,6 +261,6 @@ def _cleanup_transactions(self): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py similarity index 93% rename from synapse/storage/data_stores/main/ui_auth.py rename to synapse/storage/databases/main/ui_auth.py index 5f1b919748a6..37276f73f83b 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -81,7 +81,7 @@ async def create_ui_auth_session( session_id = stringutils.random_string(24) try: - await self.db.simple_insert( + await self.db_pool.simple_insert( table="ui_auth_sessions", values={ "session_id": session_id, @@ -97,7 +97,7 @@ async def create_ui_auth_session( return UIAuthSessionData( session_id, clientdict, uri, method, description ) - except self.db.engine.module.IntegrityError: + except self.db_pool.engine.module.IntegrityError: attempts += 1 raise StoreError(500, "Couldn't generate a session ID.") @@ -111,7 +111,7 @@ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: Raises: StoreError if the session is not found. """ - result = await self.db.simple_select_one( + result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("clientdict", "uri", "method", "description"), @@ -140,13 +140,13 @@ async def mark_ui_auth_stage_complete( # Note that we need to allow for the same stage to complete multiple # times here so that registration is idempotent. try: - await self.db.simple_upsert( + await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, values={"result": json.dumps(result)}, desc="mark_ui_auth_stage_complete", ) - except self.db.engine.module.IntegrityError: + except self.db_pool.engine.module.IntegrityError: raise StoreError(400, "Unknown session ID: %s" % (session_id,)) async def get_completed_ui_auth_stages( @@ -162,7 +162,7 @@ async def get_completed_ui_auth_stages( that auth-type. """ results = {} - for row in await self.db.simple_select_list( + for row in await self.db_pool.simple_select_list( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id}, retcols=("stage_type", "result"), @@ -186,7 +186,7 @@ async def set_ui_auth_clientdict( # The clientdict gets stored as JSON. clientdict_json = json.dumps(clientdict) - await self.db.simple_update_one( + await self.db_pool.simple_update_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, updatevalues={"clientdict": clientdict_json}, @@ -206,7 +206,7 @@ async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): Raises: StoreError if the session cannot be found. """ - await self.db.runInteraction( + await self.db_pool.runInteraction( "set_ui_auth_session_data", self._set_ui_auth_session_data_txn, session_id, @@ -216,7 +216,7 @@ async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): # Get the current value. - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, @@ -227,7 +227,7 @@ def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: A serverdict = db_to_json(result["serverdict"]) serverdict[key] = value - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, @@ -247,7 +247,7 @@ async def get_ui_auth_session_data( Raises: StoreError if the session cannot be found. """ - result = await self.db.simple_select_one( + result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), @@ -269,7 +269,7 @@ def delete_old_ui_auth_sessions(self, expiration_time: int): This is an epoch time in milliseconds. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, @@ -282,7 +282,7 @@ def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): session_ids = [r[0] for r in txn.fetchall()] # Delete the corresponding completed credentials. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions_credentials", column="session_id", @@ -291,7 +291,7 @@ def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): ) # Finally, delete the sessions. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions", column="session_id", diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/databases/main/user_directory.py similarity index 81% rename from synapse/storage/data_stores/main/user_directory.py rename to synapse/storage/databases/main/user_directory.py index 942e51fd3a73..af21fe457adb 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,12 +16,10 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.data_stores.main.state import StateFilter -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state import StateFilter +from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -38,29 +36,28 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables(self, progress, batch_size): # Get all the rooms that we want to process. def _make_staging_area(txn): @@ -85,7 +82,7 @@ def _make_staging_area(txn): """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,43 +97,45 @@ def _make_staging_area(txn): txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.db.runInteraction( + new_pos = await self.get_max_stream_id_in_current_state_deltas() + await self.db_pool.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + await self.db_pool.simple_insert( + TEMP_TABLE + "_position", {"position": new_pos} + ) - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_createtables" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup(self, progress, batch_size): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.db.simple_select_one_onecol( + position = await self.db_pool.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) - yield self.update_user_directory_stream_pos(position) + await self.update_user_directory_stream_pos(position) def _delete_staging_area(txn): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self.db.updates._end_background_update("populate_user_directory_cleanup") + await self.db_pool.updates._end_background_update( + "populate_user_directory_cleanup" + ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms(self, progress, batch_size): """ Args: progress (dict) @@ -147,7 +146,7 @@ def _populate_user_directory_process_rooms(self, progress, batch_size): # If we don't have progress filed, delete everything. if not progress: - yield self.delete_all_from_user_dir() + await self.delete_all_from_user_dir() def _get_next_batch(txn): # Only fetch 250 rooms, so we don't fetch too many at once, even @@ -172,13 +171,13 @@ def _get_next_batch(txn): return rooms_to_work_on - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_rooms" ) return 1 @@ -191,21 +190,19 @@ def _get_next_batch(txn): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( + is_public = await self.is_room_world_readable_or_publicly_joinable( room_id ) - users_with_profile = yield defer.ensureDeferred( - state.get_current_users_in_room(room_id) - ) + users_with_profile = await state.get_current_users_in_room(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -219,7 +216,7 @@ def _get_next_batch(txn): to_insert.add(user_id) if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) + await self.add_users_in_public_rooms(room_id, to_insert) to_insert.clear() else: for user_id in user_ids: @@ -239,22 +236,24 @@ def _get_next_batch(txn): # If it gets too big, stop and write to the database # to prevent storing too much in RAM. if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( + await self.add_users_who_share_private_room( room_id, to_insert ) to_insert.clear() if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) + await self.add_users_who_share_private_room(room_id, to_insert) to_insert.clear() # We've finished a room. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + await self.db_pool.simple_delete_one( + TEMP_TABLE + "_rooms", {"room_id": room_id} + ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,13 +266,12 @@ def _get_next_batch(txn): return processed_event_count - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users(self, progress, batch_size): """ If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -299,13 +297,13 @@ def _get_next_batch(txn): return users_to_work_on - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -316,26 +314,27 @@ def _get_next_batch(txn): ) for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) # We've finished processing a user. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + await self.db_pool.simple_delete_one( + TEMP_TABLE + "_users", {"user_id": user_id} + ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) return len(users_to_work_on) - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id): """Check if the room is either world_readable or publically joinable """ @@ -345,20 +344,20 @@ def is_room_world_readable_or_publicly_joinable(self, room_id): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = yield self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: if hist_vis_ev.content.get("history_visibility") == "world_readable": return True @@ -371,7 +370,7 @@ def update_profile_in_user_dir(self, user_id, display_name, avatar_url): """ def _update_profile_in_user_dir_txn(txn): - new_entry = self.db.simple_upsert_txn( + new_entry = self.db_pool.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -445,7 +444,7 @@ def _update_profile_in_user_dir_txn(txn): ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -458,7 +457,7 @@ def _update_profile_in_user_dir_txn(txn): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db.runInteraction( + return self.db_pool.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) @@ -472,7 +471,7 @@ def add_users_who_share_private_room(self, room_id, user_id_tuples): """ def _add_users_who_share_room_txn(txn): - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -484,7 +483,7 @@ def _add_users_who_share_room_txn(txn): value_values=None, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -499,7 +498,7 @@ def add_users_in_public_rooms(self, room_id, user_ids): def _add_users_in_public_rooms_txn(txn): - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -508,7 +507,7 @@ def _add_users_in_public_rooms_txn(txn): value_values=None, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) @@ -523,13 +522,13 @@ def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.db.runInteraction( + return self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() def get_user_in_directory(self, user_id): - return self.db.simple_select_one( + return self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -538,7 +537,7 @@ def get_user_in_directory(self, user_id): ) def update_user_directory_stream_pos(self, stream_id): - return self.db.simple_update_one( + return self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -552,47 +551,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + return self.db_pool.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn + ) - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.db.simple_select_onecol( + user_ids_share_pub = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.db.simple_select_onecol( + user_ids_share_priv = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -615,28 +615,27 @@ def remove_user_who_share_room(self, user_id, room_id): """ def _remove_user_who_share_room_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id): """ Returns the rooms that a user is in. @@ -646,14 +645,14 @@ def get_user_dir_rooms_user_is_in(self, user_id): Returns: list: user_id """ - rows = yield self.db.simple_select_onecol( + rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.db.simple_select_onecol( + pub_rows = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -664,42 +663,15 @@ def get_user_dir_rooms_user_is_in(self, user_id): users.update(rows) return list(users) - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. - """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) - """ - - rows = yield self.db.execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id - ) - - return [room_id for room_id, in rows] - def get_user_directory_stream_pos(self): - return self.db.simple_select_one_onecol( + return self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", desc="get_user_directory_stream_pos", ) - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -796,8 +768,8 @@ def search_user_dir(self, user_id, search_term, limit): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.db.execute( - "search_user_dir", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py similarity index 93% rename from synapse/storage/data_stores/main/user_erasure_store.py rename to synapse/storage/databases/main/user_erasure_store.py index d3038ff06d1b..ab6cb2c1f665 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -31,7 +31,7 @@ def is_user_erased(self, user_id): Returns: Deferred[bool]: True if the user has requested erasure """ - return self.db.simple_select_onecol( + return self.db_pool.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ def are_users_erased(self, user_ids): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self.db.simple_select_many_batch( + rows = yield self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -88,7 +88,7 @@ def f(txn): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db.runInteraction("mark_user_erased", f) + return self.db_pool.runInteraction("mark_user_erased", f) def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. @@ -110,4 +110,4 @@ def f(txn): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db.runInteraction("mark_user_not_erased", f) + return self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/databases/state/__init__.py similarity index 87% rename from synapse/storage/data_stores/state/__init__.py rename to synapse/storage/databases/state/__init__.py index 86e09f622994..c90d0228993c 100644 --- a/synapse/storage/data_stores/state/__init__.py +++ b/synapse/storage/databases/state/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 +from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py similarity index 91% rename from synapse/storage/data_stores/state/bg_updates.py rename to synapse/storage/databases/state/bg_updates.py index be1fe97d79cb..139085b67292 100644 --- a/synapse/storage/data_stores/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -15,10 +15,8 @@ import logging -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter @@ -62,7 +60,7 @@ def _count_state_group_hops_txn(self, txn, state_group): count = 0 while next_group: - next_group = self.db.simple_select_one_onecol_txn( + next_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -165,7 +163,7 @@ def _get_state_groups_from_groups_txn( ): break - next_group = self.db.simple_select_one_onecol_txn( + next_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -182,24 +180,23 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, index_name="state_groups_room_id_idx", table="state_groups", columns=["room_id"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding them as deltas. """ @@ -212,7 +209,7 @@ def _background_deduplicate_state(self, progress, batch_size): batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.db.execute( + rows = await self.db_pool.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -282,13 +279,13 @@ def reindex_txn(txn): if prev_state.get(key, None) != value } - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_group_edges", values={ @@ -297,13 +294,13 @@ def reindex_txn(txn): }, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -324,25 +321,24 @@ def reindex_txn(txn): "max_group": max_group, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) return False, batch_size - finished, result = yield self.db.runInteraction( + finished, result = await self.db_pool.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) return result * BATCH_SIZE_SCALE_FACTOR - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): + async def _background_index_state(self, progress, batch_size): def reindex_txn(conn): conn.rollback() if isinstance(self.database_engine, PostgresEngine): @@ -365,8 +361,10 @@ def reindex_txn(conn): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.db.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + await self.db_pool.updates._end_background_update( + self.STATE_GROUP_INDEX_UPDATE_NAME + ) return 1 diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql rename to synapse/storage/databases/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/30/state_stream.sql rename to synapse/storage/databases/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql rename to synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql rename to synapse/storage/databases/state/schema/delta/35/add_state_index.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/35/state.sql rename to synapse/storage/databases/state/schema/delta/35/state.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql rename to synapse/storage/databases/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py rename to synapse/storage/databases/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql rename to synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql similarity index 100% rename from synapse/storage/data_stores/state/schema/full_schemas/54/full.sql rename to synapse/storage/databases/state/schema/full_schemas/54/full.sql diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres similarity index 100% rename from synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres rename to synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/databases/state/store.py similarity index 94% rename from synapse/storage/data_stores/state/store.py rename to synapse/storage/databases/state/store.py index 7dada7f75f83..7f104ad93640 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -21,8 +21,8 @@ from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator @@ -53,7 +53,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateGroupDataStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the @@ -112,7 +112,7 @@ def get_state_group_delta(self, state_group): """ def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( + prev_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -123,7 +123,7 @@ def _get_state_group_delta_txn(txn): if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.db.simple_select_list_txn( + delta_ids = self.db_pool.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -135,7 +135,7 @@ def _get_state_group_delta_txn(txn): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.db.runInteraction( + return self.db_pool.runInteraction( "get_state_group_delta", _get_state_group_delta_txn ) @@ -156,7 +156,7 @@ async def _get_state_groups_from_groups( chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = await self.db.runInteraction( + res = await self.db_pool.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -393,7 +393,7 @@ def _store_state_group_txn(txn): state_group = self._state_group_seq_gen.get_next_id_txn(txn) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -402,7 +402,7 @@ def _store_state_group_txn(txn): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( + is_in_db = self.db_pool.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -417,13 +417,13 @@ def _store_state_group_txn(txn): potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -438,7 +438,7 @@ def _store_state_group_txn(txn): ], ) else: - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -484,7 +484,7 @@ def _store_state_group_txn(txn): return state_group - return self.db.runInteraction("store_state_group", _store_state_group_txn) + return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) def purge_unreferenced_state_groups( self, room_id: str, state_groups_to_delete @@ -499,7 +499,7 @@ def purge_unreferenced_state_groups( to delete. """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -511,7 +511,7 @@ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete) "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -538,15 +538,15 @@ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete) curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -583,7 +583,7 @@ async def get_previous_state_groups( A mapping from state group to previous state group. """ - rows = await self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -602,7 +602,7 @@ def purge_room_state(self, room_id, state_groups_to_delete): state_groups_to_delete (list[int]): State groups to delete """ - return self.db.runInteraction( + return self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -613,7 +613,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -624,7 +624,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -635,7 +635,7 @@ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_groups", column="id", diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 4a164834d967..f15b95e633e7 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,8 +29,8 @@ from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main.events import DeltaState +from synapse.storage.databases import Databases +from synapse.storage.databases.main.events import DeltaState from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -179,7 +179,7 @@ class EventsPersistenceStorage(object): current state and forward extremity changes. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 9cc3b51fe6a1..1c5f305132b9 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -47,8 +47,8 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): - """Prepares a database for usage. Will either create all necessary tables +def prepare_database(db_conn, database_engine, config, databases=["main", "state"]): + """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. If `config` is None then prepare_database will assert that no upgrade is @@ -60,8 +60,8 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already - data_stores (list[str]): The name of the data stores that will be used - with this database. Defaults to all data stores. + databases (list[str]): The name of the databases that will be used + with this physical database. Defaults to all databases. """ try: @@ -87,10 +87,10 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta upgraded, database_engine, config, - data_stores=data_stores, + databases=databases, ) else: - _setup_new_database(cur, database_engine, data_stores=data_stores) + _setup_new_database(cur, database_engine, databases=databases) # check if any of our configured dynamic modules want a database if config is not None: @@ -103,9 +103,9 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta raise -def _setup_new_database(cur, database_engine, data_stores): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas, including schemas from the given data +def _setup_new_database(cur, database_engine, databases): + """Sets up the physical database by finding a base set of "full schemas" and + then applying any necessary deltas, including schemas from the given data stores. The "full_schemas" directory has subdirectories named after versions. This @@ -138,8 +138,8 @@ def _setup_new_database(cur, database_engine, data_stores): Args: cur (Cursor): a database cursor database_engine (DatabaseEngine) - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. """ # We're about to set up a brand new database so we check that its @@ -176,13 +176,13 @@ def _setup_new_database(cur, database_engine, data_stores): directories.extend( os.path.join( dir_path, - "data_stores", - data_store, + "databases", + database, "schema", "full_schemas", str(max_current_ver), ) - for data_store in data_stores + for database in databases ) directory_entries = [] @@ -219,7 +219,7 @@ def _setup_new_database(cur, database_engine, data_stores): upgraded=False, database_engine=database_engine, config=None, - data_stores=data_stores, + databases=databases, is_empty=True, ) @@ -231,10 +231,10 @@ def _upgrade_existing_database( upgraded, database_engine, config, - data_stores, + databases, is_empty=False, ): - """Upgrades an existing database. + """Upgrades an existing physical database. Delta files can either be SQL stored in *.sql files, or python modules in *.py. @@ -285,8 +285,8 @@ def _upgrade_existing_database( config (synapse.config.homeserver.HomeServerConfig|None): None if we are initialising a blank database, otherwise the application config - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. is_empty (bool): Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ @@ -303,8 +303,8 @@ def _upgrade_existing_database( # some of the deltas assume that config.server_name is set correctly, so now # is a good time to run the sanity check. - if not is_empty and "main" in data_stores: - from synapse.storage.data_stores.main import check_database_before_upgrade + if not is_empty and "main" in databases: + from synapse.storage.databases.main import check_database_before_upgrade check_database_before_upgrade(cur, database_engine, config) @@ -330,11 +330,9 @@ def _upgrade_existing_database( # First we find the directories to search in delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) directories = [delta_dir] - for data_store in data_stores: + for database in databases: directories.append( - os.path.join( - dir_path, "data_stores", data_store, "schema", "delta", str(v) - ) + os.path.join(dir_path, "databases", database, "schema", "delta", str(v)) ) # Used to check if we have any duplicate file names diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 787cebfbec75..e2ddd012904c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,7 +20,7 @@ from typing_extensions import Deque -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -239,7 +239,7 @@ class MultiWriterIdGenerator: def __init__( self, db_conn, - db: Database, + db: DatabasePool, instance_name: str, table: str, instance_column: str, diff --git a/synapse/types.py b/synapse/types.py index 238b93806448..9e580f4295ca 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,11 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, TypeVar +from typing import Any, Dict, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -33,7 +34,7 @@ T_co = TypeVar("T_co", covariant=True) - class Collection(Iterable[T_co], Container[T_co], Sized): + class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore __slots__ = () @@ -141,6 +142,9 @@ def get_localpart_from_id(string): return string[1:idx] +DS = TypeVar("DS", bound="DomainSpecificString") + + class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ + __metaclass__ = abc.ABCMeta + + SIGIL = abc.abstractproperty() # type: str # type: ignore + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) @@ -166,7 +174,7 @@ def __deepcopy__(self, memo): return self @classmethod - def from_string(cls, s: str): + def from_string(cls: Type[DS], s: str) -> DS: """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( @@ -190,12 +198,12 @@ def from_string(cls, s: str): # names on one HS return cls(localpart=parts[0], domain=domain) - def to_string(self): + def to_string(self) -> str: """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) @classmethod - def is_valid(cls, s): + def is_valid(cls: Type[DS], s: str) -> bool: try: cls.from_string(s) return True @@ -235,8 +243,9 @@ class GroupID(DomainSpecificString): SIGIL = "+" @classmethod - def from_string(cls, s): - group_id = super(GroupID, cls).from_string(s) + def from_string(cls: Type[DS], s: str) -> DS: + group_id = super().from_string(s) # type: DS # type: ignore + if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index c63256d3bd04..b3f76428b65e 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -17,6 +17,7 @@ import re import attr +from canonicaljson import json from twisted.internet import defer, task @@ -24,6 +25,9 @@ logger = logging.getLogger(__name__) +# Create a custom encoder to reduce the whitespace produced by JSON encoding. +json_encoder = json.JSONEncoder(separators=(",", ":")) + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 9b09c08b8981..c2d72a82cfdf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -192,7 +192,7 @@ def set(self, key, value, callback=None): callbacks = [callback] if callback else [] self.check_thread() observable = ObservableDeferred(value, consumeErrors=True) - observer = defer.maybeDeferred(observable.observe) + observer = observable.observe() entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index eab78dd2567f..0e445e01d773 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -63,5 +63,8 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendicts without barfing -frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) +# A JSONEncoder which is capable of encoding frozendicts without barfing. +# Additionally reduce the whitespace produced by JSON encoding. +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, separators=(",", ":"), +) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ec61e1442339..13775b43f99c 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,14 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter -from twisted.internet import defer - from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge @@ -60,29 +58,37 @@ sub_metrics=["real_time_max", "real_time_sum"], ) +T = TypeVar("T", bound=Callable[..., Any]) -def measure_func(name=None): - def wrapper(func): - block_name = func.__name__ if name is None else name - if inspect.iscoroutinefunction(func): +def measure_func(name: Optional[str] = None) -> Callable[[T], T]: + """ + Used to decorate an async function with a `Measure` context manager. + + Usage: + + @measure_func() + async def foo(...): + ... - @wraps(func) - async def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = await func(self, *args, **kwargs) - return r + Which is analogous to: - else: + async def foo(...): + with Measure(...): + ... + + """ + + def wrapper(func: T) -> T: + block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r - return measured_func + return cast(T, measured_func) return wrapper diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 8794317caab1..919988d3bcfc 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -15,8 +15,6 @@ import logging import random -from twisted.internet import defer - import synapse.logging.context from synapse.api.errors import CodeMessageException @@ -54,8 +52,7 @@ def __init__(self, retry_last_ts, retry_interval, destination): self.destination = destination -@defer.inlineCallbacks -def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) Example usage: try: - limiter = yield get_retry_limiter(destination, clock, store) + limiter = await get_retry_limiter(destination, clock, store) with limiter: - response = yield do_request() + response = await do_request() except NotRetryingDestination: # We aren't ready to retry that destination. raise @@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) failure_ts = None retry_last_ts, retry_interval = (0, 0) - retry_timings = yield store.get_destination_retry_timings(destination) + retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: failure_ts = retry_timings["failure_ts"] @@ -222,10 +219,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.failure_ts is None: self.failure_ts = retry_last_ts - @defer.inlineCallbacks - def store_retry_timings(): + async def store_retry_timings(): try: - yield self.store.set_destination_retry_timings( + await self.store.set_destination_retry_timings( self.destination, self.failure_ts, retry_last_ts, diff --git a/synmark/__init__.py b/synmark/__init__.py index afe4fad8cb4e..53698bd5ab5a 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor.db.updates, "do_next_background_update"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + if hasattr(stor.db_pool.updates, "do_next_background_update"): + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 0bfb86bf1f6d..5d45689c8c3b 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -62,12 +62,15 @@ def setUp(self): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ def test_get_user_by_req_user_valid_token(self): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ def test_get_user_by_req_appservice_valid_token(self): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ def test_get_user_from_macaroon(self): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ def test_cannot_use_regular_token_as_guest(self): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 4e67503cf06b..1fab1d6b6902 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -375,8 +375,10 @@ def test_filter_presence_match(self): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -396,8 +398,10 @@ def test_filter_presence_no_match(self): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -412,8 +416,10 @@ def test_filter_room_state_match(self): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -430,8 +436,10 @@ def test_filter_room_state_no_match(self): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -465,8 +473,10 @@ def test_add_filter(self): self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -479,8 +489,10 @@ def test_get_filter(self): user_localpart=user_localpart, user_filter=user_filter_json ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index b8ca11871695..9bd515080c7d 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,9 +79,11 @@ def test_join_too_large(self): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -110,9 +112,11 @@ def test_join_too_large_admin(self): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -148,9 +152,11 @@ def test_join_too_large_once_joined(self): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(None) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -204,9 +210,11 @@ def test_join_too_large_no_admin(self): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -234,9 +242,11 @@ def test_join_too_large_admin(self): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=make_awaitable(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 628f7d8db031..2a0b7c1b56ec 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -120,7 +120,7 @@ def test_query_room_alias_exists(self): self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = defer.succeed( + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( Mock(room_id=room_id, servers=servers) ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6d45c4b2332f..e364b1bd6237 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,6 +22,7 @@ from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester +from tests.test_utils import make_awaitable from tests.unittest import override_config from .. import unittest @@ -187,7 +188,7 @@ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=defer.succeed(False)) + self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -199,8 +200,8 @@ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=defer.succeed(1)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -214,8 +215,8 @@ def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=defer.succeed(2)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 5dc37956439e..0e666492f629 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -15,7 +15,7 @@ from synapse.rest import admin from synapse.rest.client.v1 import login, room -from synapse.storage.data_stores.main import stats +from synapse.storage.databases.main import stats from tests import unittest @@ -42,16 +42,16 @@ def _add_background_updates(self): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -61,7 +61,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ def _add_background_updates(self): ) def get_all_room_state(self): - return self.store.db.simple_select_list( + return self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ def _get_current_stats(self, stats_type, stat_id): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.db.simple_select_one( + self.store.db_pool.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -109,10 +109,10 @@ def _perform_background_initial_update(self): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_initial_room(self): @@ -146,10 +146,10 @@ def test_initial_room(self): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r = self.get_success(self.get_all_room_state()) @@ -186,9 +186,9 @@ def test_initial_earliest_token(self): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_update_one( + self.store.db_pool.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -196,17 +196,17 @@ def test_initial_earliest_token(self): ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now, before the table is actually ingested, add some more events. @@ -217,7 +217,7 @@ def test_initial_earliest_token(self): # Now do the initial ingestion. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -226,7 +226,7 @@ def test_initial_earliest_token(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -236,12 +236,12 @@ def test_initial_earliest_token(self): ) ) - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) self.reactor.advance(86401) @@ -703,15 +703,15 @@ def test_incomplete_stats(self): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -723,9 +723,9 @@ def test_incomplete_stats(self): # now do the background updates - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms_2", @@ -735,7 +735,7 @@ def test_incomplete_stats(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -745,7 +745,7 @@ def test_incomplete_stats(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -756,10 +756,10 @@ def test_incomplete_stats(self): ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r1stats_complete = self._get_current_stats("room", r1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5878f7417517..64afd581bc42 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import register_federation_servlets @@ -115,7 +116,7 @@ def prepare(self, reactor, clock, hs): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( (0, []) ) @@ -126,10 +127,10 @@ def get_received_txn_response(*args): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - return defer.succeed(None) + return None hs.get_auth().check_user_in_room = check_user_in_room @@ -151,7 +152,7 @@ def get_users_in_room(room_id): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 23fcc372ddef..31ed89a5cd6b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -339,7 +339,7 @@ def _compress_shared(self, shared): def get_users_in_public_rooms(self): r = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -350,7 +350,7 @@ def get_users_in_public_rooms(self): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -362,10 +362,10 @@ def _add_background_updates(self): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -374,7 +374,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -384,7 +384,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -394,7 +394,7 @@ def _add_background_updates(self): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -437,10 +437,10 @@ def test_initial(self): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() @@ -476,10 +476,10 @@ def test_initial_share_all_users(self): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 06575ba0a6a5..ae60874ec3c2 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -65,7 +65,7 @@ def prepare(self, reactor, clock, hs): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db = hs.get_datastore().db + self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler @@ -198,7 +198,7 @@ def setUp(self): self.streamer = self.hs.get_replication_streamer() store = self.hs.get_datastore() - self.database = store.db + self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -254,7 +254,7 @@ def make_worker_hs( ) store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool + store.db_pool._db_pool = self.database_pool._db_pool repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index cec1cf928f9f..408c568a277c 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -566,7 +566,7 @@ def _is_purged(self, room_id): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", @@ -667,7 +667,7 @@ def test_purge_room(self): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f16eef15f742..17d0aae2e9b2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,6 +20,8 @@ from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -335,7 +337,9 @@ def test_register_mau_limit_reached(self): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +592,7 @@ def test_create_user_mau_limit_reached_active_admin(self): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +632,7 @@ def test_create_user_mau_limit_reached_passive_admin(self): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e54ffea1505d..0b191d13c619 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -144,7 +144,9 @@ def _test_retention_event_purged(self, room_id, increment): # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( - message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + message_handler.get_room_data( + self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + ) ) # Send a first event to the room. This is the event we'll want to be purged at the diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 8df58b4a6333..ace0a3c08d55 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -70,8 +70,8 @@ def setUp(self): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5ccda8b2bd63..ef6b775ed2b8 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,8 +23,6 @@ from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -51,8 +49,8 @@ def make_homeserver(self, reactor, clock): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 18260bb90e2e..94d2bf2eb172 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -46,7 +46,7 @@ def make_homeserver(self, reactor, clock): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ def get_user_by_access_token(token=None, allow_guest=False): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 51941f99f91c..8933b560d2cb 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -165,26 +165,6 @@ def send_event( return channel.json_body - def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") - ) - render(request, self.resource, self.hs.get_reactor()) - - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) - ) - - return channel.json_body - def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 7deaf5b24a48..53a43038f0f0 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -116,8 +116,8 @@ def test_POST_user_valid(self): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) + @override_config({"enable_registration": False}) def test_POST_disabled_registration(self): - self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e15..fa3a3ec1bddd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import read_marker, sync +from synapse.rest.client.v2_alpha import sync from tests import unittest from tests.server import TimedOutException @@ -324,156 +324,3 @@ def test_sync_backwards_typing(self): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: True): - """Syncs and compares the unread count with the expected value.""" - - request, channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, - ) - self.render(request) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py new file mode 100644 index 000000000000..2d021f656542 --- /dev/null +++ b/tests/rest/test_health.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.rest.health import HealthResource + +from tests import unittest + + +class HealthCheckTests(unittest.HomeserverTestCase): + def setUp(self): + super().setUp() + + # replace the JsonResource with a HealthResource. + self.resource = HealthResource() + + def test_health(self): + request, channel = self.make_request("GET", "/health", shorthand=False) + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f70353b0d37..2858d1355829 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,6 +27,7 @@ ) from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -79,7 +80,9 @@ def prepare(self, reactor, clock, hs): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + self._rlsn._store.get_tags_for_room = Mock( + side_effect=lambda user_id, room_id: make_awaitable({}) + ) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -258,7 +261,7 @@ def prepare(self, reactor, clock, hs): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 5a50e4fdd454..319e2c232567 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ def prepare(self, reactor, clock, hs): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ def prepare(self, reactor, clock, hs): ) ) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ def test_upsert_many(self): value_values = [["hello"], ["there"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ def test_upsert_many(self): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ def test_upsert_many(self): value_values = [["bleb"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ def test_upsert_many(self): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ef296e7dab14..98b74890d5bc 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,11 +24,11 @@ from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -178,14 +178,14 @@ def _set_last_txn(self, as_id, txn_id): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,13 +194,13 @@ def test_get_appservice_state_down(self): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @@ -339,7 +339,7 @@ def test_complete_appservice_txn_existing_in_state_table(self): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +349,14 @@ def test_get_oldest_unsent_txn(self): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +366,8 @@ def test_get_appservices_by_state_single(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +379,8 @@ def test_get_appservices_by_state_multiple(self): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( @@ -391,7 +391,7 @@ def test_get_appservices_by_state_multiple(self): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TestTransactionStore, self).__init__(database, db_conn, hs) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 940b16612997..2efbc97c2e62 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + self.updates = ( + self.hs.get_datastore().db_pool.updates + ) # type: BackgroundUpdater # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -29,7 +31,7 @@ def test_do_background_update(self): store = self.hs.get_datastore() self.get_success( - store.db.simple_insert( + store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) @@ -40,7 +42,7 @@ def test_do_background_update(self): def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db.runInteraction( + yield store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index b589506c6043..efcaeef1e77e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest @@ -57,7 +57,7 @@ def runWithConnection(func, *args, **kwargs): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool self.datastore = SQLBaseStore(db, None, hs) @@ -66,7 +66,7 @@ def runWithConnection(func, *args, **kwargs): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -78,7 +78,7 @@ def test_insert_1col(self): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -93,7 +93,7 @@ def test_select_one_1col(self): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db.simple_select_one_onecol( + value = yield self.datastore.db_pool.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -107,7 +107,7 @@ def test_select_one_3col(self): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -123,7 +123,7 @@ def test_select_one_missing(self): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -138,7 +138,7 @@ def test_select_list(self): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db.simple_select_list( + ret = yield self.datastore.db_pool.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -151,7 +151,7 @@ def test_select_list(self): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -166,7 +166,7 @@ def test_update_one_1col(self): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -181,7 +181,7 @@ def test_update_one_4cols(self): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_delete_one( + yield self.datastore.db_pool.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 43425c969a0d..3fab5a524829 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -47,12 +47,12 @@ def run_background_update(self): """ # Make sure we don't clash with in progress updates. self.assertTrue( - self.store.db.updates._all_done, "Background updates are still ongoing" + self.store.db_pool.updates._all_done, "Background updates are still ongoing" ) schema_path = os.path.join( prepare_database.dir_path, - "data_stores", + "databases", "main", "schema", "delta", @@ -64,19 +64,19 @@ def run_delta_file(txn): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "test_delete_forward_extremities", run_delta_file ) ) # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_soft_failed_extremities_handled_correctly(self): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 3b483bc7f018..224ea6fd79d3 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -86,7 +86,7 @@ def test_insert_new_client_ip_none_device_id(self): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -117,7 +117,7 @@ def test_insert_new_client_ip_none_device_id(self): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -204,10 +204,10 @@ def test_updating_monthly_active_user_when_space(self): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -225,7 +225,7 @@ def test_devices_last_seen_bg_update(self): # But clear the associated entry in devices table self.get_success( - self.store.db.simple_update( + self.store.db_pool.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -252,7 +252,7 @@ def test_devices_last_seen_bg_update(self): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -263,14 +263,14 @@ def test_devices_last_seen_bg_update(self): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # We should now get the correct result again @@ -293,10 +293,10 @@ def test_devices_last_seen_bg_update(self): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -315,7 +315,7 @@ def test_old_user_ips_pruned(self): # We should see that in the DB result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -341,7 +341,7 @@ def test_old_user_ips_pruned(self): # We should get no results. result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index c2539b353ace..87ed8f8cd1b4 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,7 +34,9 @@ def setUp(self): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device("user_id", "device_id", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name") + ) res = yield self.store.get_device("user_id", "device_id") self.assertDictContainsSubset( @@ -48,11 +50,17 @@ def test_store_new_device(self): @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device("user_id", "device1", "display_name 1") - yield self.store.store_device("user_id", "device2", "display_name 2") - yield self.store.store_device("user_id2", "device3", "display_name 3") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) - res = yield self.store.get_devices_by_user("user_id") + res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -76,13 +84,13 @@ def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"] + yield defer.ensureDeferred( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "somehost", -1, limit=100 + now_stream_id, device_updates = yield defer.ensureDeferred( + self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) # Check original device_ids are contained within these updates @@ -99,19 +107,23 @@ def _check_devices_in_updates(self, expected_device_ids, device_updates): @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device("user_id", "device_id", "display_name 1") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name 1") + ) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device("user_id", "device_id") + yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do the update - yield self.store.update_device( - "user_id", "device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "device_id", new_display_name="display_name 2" + ) ) # check it worked @@ -121,7 +133,9 @@ def test_update_device(self): @defer.inlineCallbacks def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ) ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 4e128e10478e..daac947cb2b8 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -34,8 +34,10 @@ def setUp(self): @defer.inlineCallbacks def test_room_to_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertEquals( @@ -45,24 +47,36 @@ def test_room_to_alias(self): @defer.inlineCallbacks def test_alias_to_room(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - (yield self.store.get_association_from_room_alias(self.alias)), + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ), ) @defer.inlineCallbacks def test_delete_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) - room_id = yield self.store.delete_room_alias(self.alias) + room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (yield self.store.get_association_from_room_alias(self.alias)) + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 398d546280ba..d57cdffd8ba4 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -30,11 +30,13 @@ def test_key_without_device_name(self): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield self.store.set_e2e_device_keys("user", "device", now, json) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -45,7 +47,7 @@ def test_reupload_key(self): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) changed = yield self.store.set_e2e_device_keys("user", "device", now, json) self.assertTrue(changed) @@ -61,9 +63,13 @@ def test_get_key_with_device_name(self): json = {"key": "value"} yield self.store.set_e2e_device_keys("user", "device", now, json) - yield self.store.store_device("user", "device", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user", "device", "display_name") + ) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -75,18 +81,18 @@ def test_get_key_with_device_name(self): def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device("user1", "device1", None) - yield self.store.store_device("user1", "device2", None) - yield self.store.store_device("user2", "device1", None) - yield self.store.store_device("user2", "device2", None) + yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3aeec0dc0f52..d4c3b867e350 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -56,7 +56,9 @@ def insert_event(txn, i): ) for i in range(0, 20): - self.get_success(self.store.db.runInteraction("insert", insert_event, i)) + self.get_success( + self.store.db_pool.runInteraction("insert", insert_event, i) + ) # this should get the last ten r = self.get_success(self.store.get_prev_events_for_room(room_id)) @@ -81,13 +83,13 @@ def insert_event(txn, i, room_id): for i in range(0, 20): self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room1) + self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction("insert", insert_event, i, room2) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction("insert", insert_event, i, room3) ) # Test simple case @@ -164,7 +166,7 @@ def insert_event(txn, event_id, stream_ordering): depth = depth_map[event_id] - self.store.db.simple_insert_txn( + self.store.db_pool.simple_insert_txn( txn, table="events", values={ @@ -179,7 +181,7 @@ def insert_event(txn, event_id, stream_ordering): }, ) - self.store.db.simple_insert_many_txn( + self.store.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -192,7 +194,7 @@ def insert_event(txn, event_id, stream_ordering): for event_id in auth_graph: next_stream_ordering += 1 self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "insert", insert_event, event_id, next_stream_ordering ) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 2b1580feebaa..857db071d4b1 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -60,7 +60,7 @@ def test_count_aggregation(self): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db.runInteraction( + counts = yield self.store.db_pool.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -81,7 +81,7 @@ def _inject_actions(stream, action): event.event_id, {user_id: action} ) ) - yield self.store.db.runInteraction( + yield self.store.db_pool.runInteraction( "", self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -89,12 +89,12 @@ def _inject_actions(stream, action): ) def _rotate(stream): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -123,7 +123,7 @@ def _mark_read(stream, depth): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db.simple_delete( + yield self.store.db_pool.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -142,7 +142,7 @@ def _mark_read(stream, depth): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db.simple_insert( + return self.store.db_pool.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 55e9ecf2641c..e845410dae81 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -14,7 +14,7 @@ # limitations under the License. -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db = self.store.db # type: Database + self.db_pool = self.store.db_pool # type: DatabasePool - self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) def _setup_db(self, txn): txn.execute("CREATE SEQUENCE foobar_seq") @@ -47,7 +47,7 @@ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator def _create(conn): return MultiWriterIdGenerator( conn, - self.db, + self.db_pool, instance_name=instance_name, table="foobar", instance_column="instance_name", @@ -55,7 +55,7 @@ def _create(conn): sequence_name="foobar_seq", ) - return self.get_success(self.db.runWithConnection(_create)) + return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): def _insert(txn): @@ -65,7 +65,7 @@ def _insert(txn): (instance_name,), ) - self.get_success(self.db.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -178,7 +178,7 @@ def _get_next_txn(txn): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token("master"), 7) - self.get_success(self.db.runInteraction("test", _get_next_txn)) + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9c04e9257731..9870c748834f 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,6 +19,7 @@ from synapse.api.constants import UserTypes from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -78,7 +79,7 @@ def test_initialise_reserved_users(self): # XXX why are we doing this here? this function is only run at startup # so it is odd to re-run it here. self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) ) @@ -204,7 +205,7 @@ def test_reap_monthly_active_users_reserved_users(self): self.store.user_add_threepid(user, "email", email, now, now) ) - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -230,7 +231,9 @@ def test_populate_monthly_users_is_guest(self): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -238,7 +241,9 @@ def test_populate_monthly_users_is_guest(self): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -251,7 +256,9 @@ def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -280,7 +287,7 @@ def test_get_reserved_real_user_account(self): ] self.hs.config.mau_limits_reserved_threepids = threepids - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -293,8 +300,12 @@ def test_get_reserved_real_user_account(self): self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) @@ -333,7 +344,9 @@ def test_track_monthly_users_without_cap(self): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0f0e1cd09b6e..1ea35d60c11c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -251,6 +251,10 @@ def build(self, prev_event_ids): def room_id(self): return self._base_builder.room_id + @property + def type(self): + return self._base_builder.type + event_1, context_1 = self.get_success( self.event_creation_handler.create_new_client_event( EventIdManglingBuilder( @@ -343,7 +347,7 @@ def test_redact_censor(self): ) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -361,7 +365,7 @@ def test_redact_censor(self): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 71a40a0a4911..840db6607286 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -58,8 +58,10 @@ def test_register(self): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) result = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -74,11 +76,15 @@ def test_add_tokens(self): def test_user_delete_access_tokens(self): # add some tokens yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) ) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) # now delete some diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index f282921538c1..17c9da483867 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -179,10 +179,10 @@ def prepare(self, reactor, clock, homeserver): def test_can_rerun_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now let's create a room, which will insert a membership @@ -192,7 +192,7 @@ def test_can_rerun_update(self): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -203,12 +203,12 @@ def test_can_rerun_update(self): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 6a545d2eb028..ecfafe68a965 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -40,7 +40,7 @@ def setUp(self): def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( @@ -51,7 +51,7 @@ def test_search_user_dir(self): def test_search_user_dir_all_users(self): self.hs.config.user_directory_search_all_users = True try: - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) self.assertDictEqual( diff --git a/tests/test_federation.py b/tests/test_federation.py index c2f12c2741e3..f2fa42bfb925 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed @@ -10,6 +25,7 @@ from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver +from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -173,7 +189,7 @@ def query_user_devices(destination, user_id): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) + store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. diff --git a/tests/unittest.py b/tests/unittest.py index 68d2586efd42..d0bba3ddefd5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -241,20 +241,16 @@ def setUp(self): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -422,8 +418,8 @@ def setup_test_homeserver(self, *args, **kwargs): async def run_bg_updates(): with LoggingContext("run_bg_updates", request="run_bg_updates-1"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -571,7 +567,7 @@ def add_extremity(self, room_id, event_id): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db.simple_insert( + self.hs.get_datastore().db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e348694ad7f..bc42ffce880c 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request self.pump(1) @@ -36,18 +34,14 @@ def test_new_destination(self): with limiter: pass - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) def test_limiter(self): """General test case which walks through the process of a failing request""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -58,29 +52,22 @@ def test_limiter(self): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], failure_ts) self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) # now if we try again we should get a failure - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - self.failureResultOf(d, NotRetryingDestination) + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) # # advance the clock and try again # self.pump(MIN_RETRY_INTERVAL) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -91,12 +78,7 @@ def test_limiter(self): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], retry_ts) self.assertGreaterEqual( @@ -110,9 +92,7 @@ def test_limiter(self): # one more go, with success # self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) with limiter: @@ -121,7 +101,5 @@ def test_limiter(self): # wait for the update to land self.pump() - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) diff --git a/tox.ini b/tox.ini index 2b1db0f7f780..e5413eb1102a 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ commands = mypy \ synapse/appservice \ synapse/config \ synapse/event_auth.py \ + synapse/events/builder.py \ synapse/events/spamcheck.py \ synapse/federation \ synapse/handlers/auth.py \ @@ -186,6 +187,7 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/federation.py \ synapse/handlers/identity.py \ + synapse/handlers/message.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/room_member.py \ @@ -198,19 +200,23 @@ commands = mypy \ synapse/logging/ \ synapse/metrics \ synapse/module_api \ + synapse/notifier.py \ synapse/push/pusherpool.py \ synapse/push/push_rule_evaluator.py \ synapse/replication \ synapse/rest \ + synapse/server.py \ synapse/server_notices \ synapse/spam_checker_api \ - synapse/storage/data_stores/main/ui_auth.py \ + synapse/storage/databases/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ + synapse/types.py \ synapse/util/caches/stream_change_cache.py \ + synapse/util/metrics.py \ tests/replication \ tests/test_utils \ tests/rest/client/v2_alpha/test_auth.py \