Skip to content
Snippets Groups Projects
Commit d2d8ed48 authored by Erik Johnston's avatar Erik Johnston
Browse files

Optimise caches with single key

parent 5d829042
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache
......@@ -163,10 +164,6 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
......@@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
)
def get_cache_key(args, kwargs):
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
......@@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase):
else:
yield self.arg_defaults[nm]
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if self.num_args == 1:
nm = self.arg_names[0]
def get_cache_key(args, kwargs):
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return self.arg_defaults[nm]
else:
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
cache_key = tuple(get_cache_key(args, kwargs))
cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
......@@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
# If our cache_key is a string, try to convert to ascii to save
# a bit of space in large caches
if isinstance(cache_key, basestring):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()
......@@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase):
else:
return observer
wrapped.invalidate = cache.invalidate
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment