Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Limit the number of in-flight /keys/query requests from a single device. #10144

Merged
merged 2 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10144.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Limit the number of in-flight `/keys/query` requests from a single device.
350 changes: 181 additions & 169 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,15 @@ def __init__(self, hs: "HomeServer"):
"client_keys", self.on_federation_query_client_keys
)

# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
)

@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client

Expand All @@ -105,191 +111,197 @@ async def query_devices(
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""

device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]

# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}

for user_id, device_ids in device_keys_query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)

# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys

# Get cached cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)

# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]

# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}

for user_id, device_ids in device_keys_query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
query_list.append((user_id, None))

(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result

# check for missing cross-signing keys.
for user_id in remote_queries.keys():
cached_cross_master = user_id in cross_signing_keys["master_keys"]
cached_cross_selfsigning = (
user_id in cross_signing_keys["self_signing_keys"]
)

# check if we are missing only one of cross-signing master or
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)

# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]

# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""

destination_query = remote_queries_not_in_cache[destination]

# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue

if device_list:
continue
remote_queries[user_id] = device_ids

set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)

# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys

room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# Get cached cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)

# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
(user_id, device_id) for device_id in device_ids
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)

user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)

if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
query_list.append((user_id, None))

# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result

# check for missing cross-signing keys.
for user_id in remote_queries.keys():
cached_cross_master = user_id in cross_signing_keys["master_keys"]
cached_cross_selfsigning = (
user_id in cross_signing_keys["self_signing_keys"]
)

try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
# check if we are missing only one of cross-signing master or
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)

# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]

# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""

destination_query = remote_queries_not_in_cache[destination]

# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue

if device_list:
continue

room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue

# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)

user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)

if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return

# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)

for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)

if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
results[user_id] = keys

if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key

except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key

await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)

await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)

ret = {"device_keys": results, "failures": failures}
ret = {"device_keys": results, "failures": failures}

ret.update(cross_signing_keys)
ret.update(cross_signing_keys)

return ret
return ret

async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
Expand Down
5 changes: 4 additions & 1 deletion synapse/rest/client/v2_alpha/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,12 @@ def __init__(self, hs):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
return 200, result


Expand Down
Loading