Skip to content
Snippets Groups Projects
Unverified Commit f97f9485 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Split fetching device keys and signatures into two transactions (#8233)

I think this is simpler (and moves stuff out of the db threads)
parent 208e1d3e
No related branches found
No related tags found
No related merge requests found
Refactor queries for device keys and cross-signatures.
......@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
......@@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
# cross-signing sigs on this device.
# dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
class EndToEndKeyWorkerStore(SQLBaseStore):
......@@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
included.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
......@@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
target_device_signatures = target_device_result.signatures
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
log_kv(result)
return result
def _get_e2e_device_keys_and_signatures_txn(
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
......@@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)
signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)
query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
......@@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)
return result
txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]
target_user_result = result.get(target_user_id)
if not target_user_result:
continue
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
Returns signatures made by the owners of the devices.
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
"""
signature_query_clauses = []
signature_query_params = []
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signing_user_signatures[signing_key_id] = signature
signature_query_params.extend([user_id, device_id, user_id])
return result
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)
txn.execute(signature_sql, signature_query_params)
return txn.fetchall()
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment