Skip to content

Commit 2a321ba

Browse files
authored
Issue one time keys in upload order (#17903)
Currently, one-time keys (OTKs) are issued in a somewhat random order. (In practice, they are issued according to the lexicographical order of their key IDs.) That can lead to a situation where a client gives up hope of a given OTK ever being used, and discards it, whilst it is still on the server waiting to be issued. Related: element-hq/element-meta#2356
1 parent eda735e commit 2a321ba

File tree

5 files changed

+116
-8
lines changed

5 files changed

+116
-8
lines changed

changelog.d/17903.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.

synapse/handlers/e2e_keys.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ async def claim_local_one_time_keys(
615615
3. Attempt to fetch fallback keys from the database.
616616
617617
Args:
618-
local_query: An iterable of tuples of (user ID, device ID, algorithm).
618+
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
619619
always_include_fallback_keys: True to always include fallback keys.
620620
621621
Returns:

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ def __init__(
9999
unique=True,
100100
)
101101

102+
self.db_pool.updates.register_background_index_update(
103+
update_name="add_otk_ts_added_index",
104+
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
105+
table="e2e_one_time_keys_json",
106+
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
107+
)
108+
102109

103110
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
104111
def __init__(
@@ -1122,7 +1129,7 @@ async def claim_e2e_one_time_keys(
11221129
"""Take a list of one time keys out of the database.
11231130
11241131
Args:
1125-
query_list: An iterable of tuples of (user ID, device ID, algorithm).
1132+
query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
11261133
11271134
Returns:
11281135
A tuple (results, missing) of:
@@ -1310,9 +1317,14 @@ def _claim_e2e_one_time_key_simple(
13101317
OTK was found.
13111318
"""
13121319

1320+
# Return the oldest keys from this device (based on `ts_added_ms`).
1321+
# Doing so means that keys are issued in the same order they were uploaded,
1322+
# which reduces the chances of a client expiring its copy of a (private)
1323+
# key while the public key is still on the server, waiting to be issued.
13131324
sql = """
13141325
SELECT key_id, key_json FROM e2e_one_time_keys_json
13151326
WHERE user_id = ? AND device_id = ? AND algorithm = ?
1327+
ORDER BY ts_added_ms
13161328
LIMIT ?
13171329
"""
13181330

@@ -1354,13 +1366,22 @@ def _claim_e2e_one_time_keys_bulk(
13541366
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
13551367
for each OTK claimed.
13561368
"""
1369+
# Find, delete, and return the oldest keys from each device (based on
1370+
# `ts_added_ms`).
1371+
#
1372+
# Doing so means that keys are issued in the same order they were uploaded,
1373+
# which reduces the chances of a client expiring its copy of a (private)
1374+
# key while the public key is still on the server, waiting to be issued.
13571375
sql = """
13581376
WITH claims(user_id, device_id, algorithm, claim_count) AS (
13591377
VALUES ?
13601378
), ranked_keys AS (
13611379
SELECT
13621380
user_id, device_id, algorithm, key_id, claim_count,
1363-
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
1381+
ROW_NUMBER() OVER (
1382+
PARTITION BY (user_id, device_id, algorithm)
1383+
ORDER BY ts_added_ms
1384+
) AS r
13641385
FROM e2e_one_time_keys_json
13651386
JOIN claims USING (user_id, device_id, algorithm)
13661387
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
--
2+
-- This file is licensed under the Affero General Public License (AGPL) version 3.
3+
--
4+
-- Copyright (C) 2024 New Vector, Ltd
5+
--
6+
-- This program is free software: you can redistribute it and/or modify
7+
-- it under the terms of the GNU Affero General Public License as
8+
-- published by the Free Software Foundation, either version 3 of the
9+
-- License, or (at your option) any later version.
10+
--
11+
-- See the GNU Affero General Public License for more details:
12+
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
13+
14+
15+
-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
16+
-- efficiently be issued in the same order they were uploaded.
17+
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
18+
(8803, 'add_otk_ts_added_index', '{}');

tests/handlers/test_e2e_keys.py

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,30 @@ def test_change_one_time_keys(self) -> None:
151151
def test_claim_one_time_key(self) -> None:
152152
local_user = "@boris:" + self.hs.hostname
153153
device_id = "xyz"
154-
keys = {"alg1:k1": "key1"}
155-
156154
res = self.get_success(
157155
self.handler.upload_keys_for_user(
158-
local_user, device_id, {"one_time_keys": keys}
156+
local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
159157
)
160158
)
161159
self.assertDictEqual(
162160
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
163161
)
164162

165-
res2 = self.get_success(
163+
# Keys should be returned in the order they were uploaded. To test, advance time
164+
# a little, then upload a second key with an earlier key ID; it should get
165+
# returned second.
166+
self.reactor.advance(1)
167+
res = self.get_success(
168+
self.handler.upload_keys_for_user(
169+
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
170+
)
171+
)
172+
self.assertDictEqual(
173+
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
174+
)
175+
176+
# Now claim both keys back, one at a time. They should be returned in the order they were uploaded.
177+
res = self.get_success(
166178
self.handler.claim_one_time_keys(
167179
{local_user: {device_id: {"alg1": 1}}},
168180
self.requester,
@@ -171,12 +183,27 @@ def test_claim_one_time_key(self) -> None:
171183
)
172184
)
173185
self.assertEqual(
174-
res2,
186+
res,
175187
{
176188
"failures": {},
177189
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
178190
},
179191
)
192+
res = self.get_success(
193+
self.handler.claim_one_time_keys(
194+
{local_user: {device_id: {"alg1": 1}}},
195+
self.requester,
196+
timeout=None,
197+
always_include_fallback_keys=False,
198+
)
199+
)
200+
self.assertEqual(
201+
res,
202+
{
203+
"failures": {},
204+
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
205+
},
206+
)
180207

181208
def test_claim_one_time_key_bulk(self) -> None:
182209
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +363,47 @@ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
336363
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
337364
)
338365

366+
def test_claim_one_time_key_bulk_ordering(self) -> None:
367+
"""Keys returned by the bulk claim call should be returned in the correct order"""
368+
369+
# Alice has lots of keys, uploaded in a specific order
370+
alice = f"@alice:{self.hs.hostname}"
371+
alice_dev = "alice_dev_1"
372+
373+
self.get_success(
374+
self.handler.upload_keys_for_user(
375+
alice,
376+
alice_dev,
377+
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
378+
)
379+
)
380+
# Advance time by 1s, to ensure that there is a difference in upload time.
381+
self.reactor.advance(1)
382+
self.get_success(
383+
self.handler.upload_keys_for_user(
384+
alice,
385+
alice_dev,
386+
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
387+
)
388+
)
389+
390+
# Now claim some, and check we get the right ones.
391+
claim_res = self.get_success(
392+
self.handler.claim_one_time_keys(
393+
{alice: {alice_dev: {"alg1": 2}}},
394+
self.requester,
395+
timeout=None,
396+
always_include_fallback_keys=False,
397+
)
398+
)
399+
# We should get the first-uploaded keys, even though they have later key ids.
400+
# k20, k21 and k22 were uploaded in the same request, so any two of them are acceptable.
401+
self.assertEqual(claim_res["failures"], {})
402+
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
403+
self.assertEqual(len(claimed_keys), 2)
404+
for key_id in claimed_keys.keys():
405+
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
406+
339407
def test_fallback_key(self) -> None:
340408
local_user = "@boris:" + self.hs.hostname
341409
device_id = "xyz"

0 commit comments

Comments
 (0)