Skip to content
Snippets Groups Projects
Unverified commit c37dad67, authored by Erik Johnston, committed by GitHub
Browse files

Improve event caching code (#10119)

Ensure we only load an event from the DB once when the same event is requested multiple times at once.
parent 11540be5
No related branches found
No related tags found
No related merge requests found
Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.
......@@ -14,7 +14,6 @@
import logging
import threading
from collections import namedtuple
from typing import (
Collection,
Container,
......@@ -27,6 +26,7 @@ from typing import (
overload,
)
import attr
from constantly import NamedConstant, Names
from typing_extensions import Literal
......@@ -42,7 +42,11 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.logging.context import (
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
......@@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
......@@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@attr.s(auto_attribs=True, slots=True)
class _EventCacheEntry:
    """Value stored in the event cache: an event plus an optional redacted form."""

    # The event as an EventBase.
    event: EventBase
    # Optional second copy of the event.
    # NOTE(review): presumably None when the event has not been redacted — confirm.
    redacted_event: Optional[EventBase]
class EventRedactBehaviour(Names):
......@@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore):
max_size=hs.config.caches.event_cache_size,
)
# Map from event ID to a deferred that will result in a map from event
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
......@@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore):
return events
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
......@@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore):
Args:
event_ids (Iterable[str]): The event_ids of the events to fetch
event_ids: The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events. If False,
allow_rejected: Whether to include rejected events. If False,
rejected events are omitted from the response.
Returns:
Dict[str, _EventCacheEntry]:
map from event id to result
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
event_ids, allow_rejected=allow_rejected
event_ids,
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
missing_events_ids = {e for e in event_ids if e not in event_entry_map}
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
already_fetching: Dict[str, defer.Deferred] = {}
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
already_fetching[event_id] = deferred.observe()
missing_events_ids.difference_update(already_fetching)
if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
event_entry_map.update(missing_events)
event_entry_map.update(missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
if already_fetching:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
already_fetching.values(),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
event_entry_map.update(result)
if not allow_rejected:
event_entry_map = {
event_id: entry
for event_id, entry in event_entry_map.items()
if not entry.event.rejected_reason
}
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
    """Drop a single event's entry from the event cache.

    Args:
        event_id: ID of the event whose cached entry should be invalidated.
    """
    # The cache is invalidated with a 1-tuple key, so wrap the bare ID first.
    cache_key = (event_id,)
    self._get_event_cache.invalidate(cache_key)
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
"""Fetch events from the caches.
Args:
events (Iterable[str]): list of event_ids to fetch
allow_rejected (bool): Whether to return events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
May return rejected events.
Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
Args:
events: list of event_ids to fetch
update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = {}
......@@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore):
if not ret:
continue
if allow_rejected or not ret.event.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
event_map[event_id] = ret
return event_map
......@@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
async def _get_events_from_db(self, event_ids, allow_rejected=False):
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
Returned events will be added to the cache for future lookups.
Unknown events are omitted from the response.
Args:
event_ids (Iterable[str]): The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events. If False,
rejected events are omitted from the response.
event_ids: The event_ids of the events to fetch
Returns:
Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids
......@@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason = row["rejected_reason"]
if not allow_rejected and rejected_reason:
continue
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
......
......@@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
event_map = self._get_events_from_cache(
member_event_ids, allow_rejected=False, update_metrics=False
)
event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
missing_member_event_ids = []
for event_id in member_event_ids:
ev_entry = event_map.get(event_id)
if ev_entry:
if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
......
......@@ -14,7 +14,10 @@
import json
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
......@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
class EventCacheTestCase(unittest.HomeserverTestCase):
    """Check how many database fetches the event cache layers perform."""

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        """Register a user, create a room, send one event and empty the cache."""
        self.store: EventsWorkerStore = hs.get_datastore()

        self.user = self.register_user("user", "pass")
        self.token = self.login(self.user, "pass")

        self.room = self.helper.create_room_as(self.user, tok=self.token)

        send_response = self.helper.send(self.room, tok=self.token)
        self.event_id = send_response["event_id"]

        # Clear the event cache so that each test begins with nothing cached.
        self.store._get_event_cache.clear()

    def test_simple(self):
        """A single lookup of an uncached event records one DB event fetch."""
        with LoggingContext("test") as ctx:
            fetch = self.store.get_event(self.event_id)
            self.get_success(fetch)

            # The cache was cleared in prepare(), so this lookup had to hit the DB.
            usage = ctx.get_resource_usage()
            self.assertEqual(usage.evt_db_fetch_count, 1)

    def test_dedupe(self):
        """Two simultaneous requests for one event record only one DB event
        fetch.
        """
        with LoggingContext("test") as ctx:
            requested_ids = [self.event_id, self.event_id]
            gathered = yieldable_gather_results(self.store.get_event, requested_ids)
            self.get_success(gathered)

            # Both requests were for the same event, so only one fetch is expected.
            usage = ctx.get_resource_usage()
            self.assertEqual(usage.evt_db_fetch_count, 1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment