Skip to content

Commit

Permalink
Handle OTK uploads off master (element-hq#17271)
Browse files Browse the repository at this point in the history
And fallback keys uploads. Only device keys need handling on master
  • Loading branch information
erikjohnston authored and Mic92 committed Jun 14, 2024
1 parent 7cc608e commit f4229fd
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog.d/17271.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle OTK uploads off master.
84 changes: 55 additions & 29 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
Expand Down Expand Up @@ -89,6 +90,12 @@ def __init__(self, hs: "HomeServer"):
edu_updater.incoming_signing_key_update,
)

self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
)

# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
Expand Down Expand Up @@ -796,36 +803,17 @@ async def upload_keys_for_user(
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
await self.device_key_uploader(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
Expand Down Expand Up @@ -861,18 +849,56 @@ async def upload_keys_for_user(
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)

result = await self.store.count_e2e_one_time_keys(user_id, device_id)

set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}

@tag_args
async def upload_device_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> None:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
device_keys: the `device_keys` of an /keys/upload request.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

device_keys = keys["device_keys"]
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])

# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id)

result = await self.store.count_e2e_one_time_keys(user_id, device_id)

set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}

async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable
Expand Down Expand Up @@ -105,13 +104,8 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()

if hs.config.worker.worker_app is None:
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
self._clock = hs.get_clock()
self._store = hs.get_datastores().main

async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
Expand Down Expand Up @@ -151,9 +145,10 @@ async def on_POST(
400, "To upload keys, you must pass device_id when authenticating"
)

result = await self.key_uploader(
result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body
)

return 200, result


Expand Down

0 comments on commit f4229fd

Please sign in to comment.