Skip to content
Snippets Groups Projects
Commit 7335f0ad authored by Erik Johnston's avatar Erik Johnston
Browse files

Add ReadWriteLock

parent 2d21d43c
No related branches found
No related tags found
No related merge requests found
......@@ -194,3 +194,85 @@ class Linearizer(object):
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
class ReadWriteLock(object):
"""A deferred style read write lock.
Example:
with (yield read_write_lock.read("test_key")):
# do some work
"""
# IMPLEMENTATION NOTES
#
# We track the most recent queued reader and writer deferreds (which get
# resolved when they release the lock).
#
# Read: We know its safe to acquire a read lock when the latest writer has
# been resolved. The new reader is appeneded to the list of latest readers.
#
# Write: We know its safe to acquire the write lock when both the latest
# writers and readers have been resolved. The new writer replaces the latest
# writer.
def __init__(self):
# Latest readers queued
self.key_to_current_readers = {}
# Latest writer queued
self.key_to_current_writer = {}
@defer.inlineCallbacks
def read(self, key):
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers.add(new_defer)
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
yield curr_writer
@contextmanager
def _ctx_manager():
try:
yield
finally:
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
defer.returnValue(_ctx_manager())
@defer.inlineCallbacks
def write(self, key):
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
# We wait on all latest readers and writer.
to_wait_on = list(curr_readers)
if curr_writer:
to_wait_on.append(curr_writer)
# We can clear the list of current readers since the new writer waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
yield defer.gatherResults(to_wait_on)
@contextmanager
def _ctx_manager():
try:
yield
finally:
new_defer.callback(None)
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from synapse.util.async import ReadWriteLock
class ReadWriteLockTestCase(unittest.TestCase):
def _assert_called_before_not_after(self, lst, first_false):
for i, d in enumerate(lst[:first_false]):
self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
for i, d in enumerate(lst[first_false:]):
self.assertFalse(
d.called, msg="%d was unexpectedly true" % (i + first_false)
)
def test_rwlock(self):
rwlock = ReadWriteLock()
key = object()
ds = [
rwlock.read(key), # 0
rwlock.read(key), # 1
rwlock.write(key), # 2
rwlock.write(key), # 3
rwlock.read(key), # 4
rwlock.read(key), # 5
rwlock.write(key), # 6
]
self._assert_called_before_not_after(ds, 2)
with ds[0].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 2)
with ds[1].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 3)
with ds[2].result:
self._assert_called_before_not_after(ds, 3)
self._assert_called_before_not_after(ds, 4)
with ds[3].result:
self._assert_called_before_not_after(ds, 4)
self._assert_called_before_not_after(ds, 6)
with ds[5].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 6)
with ds[4].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 7)
with ds[6].result:
pass
d = rwlock.write(key)
self.assertTrue(d.called)
with d.result:
pass
d = rwlock.read(key)
self.assertTrue(d.called)
with d.result:
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment