Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove unused VerifyKey.expired and .time_added fields #5235

Merged
merged 2 commits into from
May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5234.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rewrite store_server_verify_key to store several keys at once.
1 change: 1 addition & 0 deletions changelog.d/5235.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove unused VerifyKey.expired and .time_added fields.
62 changes: 14 additions & 48 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,11 @@ def get_server_verify_key_v2_indirect(
raise_from(KeyLookupError("Remote server returned an error"), e)

keys = {}
added_keys = []

responses = query_response["server_keys"]
time_now_ms = self.clock.time_msec()

for response in responses:
for response in query_response["server_keys"]:
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
Expand Down Expand Up @@ -492,21 +493,13 @@ def get_server_verify_key_v2_indirect(
)
server_name = response["server_name"]

added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items()
)
keys.setdefault(server_name, {}).update(processed_response)

yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)

defer.returnValue(keys)
Expand All @@ -519,6 +512,7 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
if requested_key_id in keys:
continue

time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
Expand Down Expand Up @@ -548,12 +542,13 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids):
requested_ids=[requested_key_id],
response_json=response,
)

yield self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)

yield self.store_keys(
server_name=server_name, from_server=server_name, verify_keys=keys
)
defer.returnValue({server_name: keys})

@defer.inlineCallbacks
Expand Down Expand Up @@ -594,7 +589,6 @@ def process_v2_response(self, from_server, response_json, requested_ids=[]):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key

old_verify_keys = {}
Expand All @@ -603,8 +597,6 @@ def process_v2_response(self, from_server, response_json, requested_ids=[]):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key

server_name = response_json["server_name"]
Expand Down Expand Up @@ -650,32 +642,6 @@ def process_v2_response(self, from_server, response_json, requested_ids=[]):

defer.returnValue(response_keys)

def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_verify_key,
server_name,
server_name,
key.time_added,
key,
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)


@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
Expand Down
65 changes: 39 additions & 26 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,38 +84,51 @@ def _txn(txn):

return self.runInteraction("get_server_verify_keys", _txn)

def store_server_verify_key(
self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server.
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
Args:
server_name (str): The name of the server.
from_server (str): Where the verification key was looked up
time_now_ms (int): The time now in milliseconds
verify_key (nacl.signing.VerifyKey): The NACL verify key.
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)

# XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": db_binary_type(verify_key.encode()),
},
key_values = []
value_values = []
invalidations = []
for server_name, key_id, verify_key in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
db_binary_type(verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
txn.call_after(
self._get_server_verify_key.invalidate, ((server_name, key_id),)
)

return self.runInteraction("store_server_verify_key", _txn)
invalidations.append((server_name, key_id))

def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i, ))
return res

return self.runInteraction(
"store_server_verify_keys",
self._simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"verify_key",
),
value_values=value_values,
).addCallback(_invalidate)

def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
Expand Down
14 changes: 12 additions & 2 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,18 @@ def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)

key1 = signedjson.key.generate_signing_key(1)
r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
key1_id = "%s:%s" % (key1.alg, key1.version)

r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
[
(
"server9",
key1_id,
signedjson.key.get_verify_key(key1),
),
],
)
self.get_success(r)
json1 = {}
Expand Down
44 changes: 30 additions & 14 deletions tests/storage/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,32 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()

d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
d = store.store_server_verify_keys(
"from_server",
10,
[
("server1", key_id_1, KEY_1),
("server1", key_id_2, KEY_2),
],
)
self.get_success(d)

d = store.get_server_verify_keys(
[
("server1", "ed25519:key1"),
("server1", "ed25519:key2"),
("server1", "ed25519:key3"),
]
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)

self.assertEqual(len(res.keys()), 3)
self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
res1 = res[("server1", key_id_1)]
self.assertEqual(res1, KEY_1)
self.assertEqual(res1.version, "key1")

res2 = res[("server1", key_id_2)]
self.assertEqual(res2, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.version, "KEY_ID_2")

# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
Expand All @@ -60,9 +69,14 @@ def test_cache(self):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"

d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
d = store.store_server_verify_keys(
"from_server",
0,
[
("srv1", key_id_1, KEY_1),
("srv1", key_id_2, KEY_2),
],
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
Expand All @@ -81,7 +95,9 @@ def test_cache(self):
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, new_key_2)]
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
Expand Down