Skip to content
Snippets Groups Projects
Commit 8437e238 authored by Erik Johnston's avatar Erik Johnston
Browse files

Port SyncHandler to async/await

parent d085a8a0
No related branches found
No related tags found
No related merge requests found
......@@ -16,8 +16,6 @@
import logging
import random
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
......@@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function
def get_stream(
async def get_stream(
self,
auth_user_id,
pagin_config,
......@@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
"""
if room_id:
blocked = yield self.store.is_room_blocked(room_id)
blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(auth_user_id)
await self._server_notices_sender.on_user_syncing(auth_user_id)
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
context = await presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence
)
with context:
......@@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
events, tokens = await self.notifier.get_events_for(
auth_user,
pagin_config,
timeout,
......@@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_users_in_room(
users = await self.state.get_current_users_in_room(
event.room_id
)
states = yield presence_handler.get_states(users, as_event=True)
states = await presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
ev = await presence_handler.get_state(
UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
......@@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events(
chunks = await self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
......@@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
async def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
Args:
......@@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this
event.
"""
event = yield self.store.get_event(event_id, check_room_id=room_id)
event = await self.store.get_event(event_id, check_room_id=room_id)
if not event:
return None
users = yield self.store.get_users_in_room(event.room_id)
users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking
)
......
This diff is collapsed.
......@@ -304,8 +304,7 @@ class Notifier(object):
without waking up any of the normal user event streams"""
self.notify_replication()
@defer.inlineCallbacks
def wait_for_events(
async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
):
"""Wait until the callback returns a non empty response or the
......@@ -313,9 +312,9 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
current_token = yield self.event_sources.get_current_token()
current_token = await self.event_sources.get_current_token()
if room_ids is None:
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
......@@ -344,11 +343,11 @@ class Notifier(object):
self.hs.get_reactor(),
)
with PreserveLoggingContext():
yield listener.deferred
await listener.deferred
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
result = await callback(prev_token, current_token)
if result:
break
......@@ -364,12 +363,11 @@ class Notifier(object):
# This happened if there was no timeout or if the timeout had
# already expired.
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
result = await callback(prev_token, current_token)
return result
@defer.inlineCallbacks
def get_events_for(
async def get_events_for(
self,
user,
pagination_config,
......@@ -391,15 +389,14 @@ class Notifier(object):
"""
from_token = pagination_config.from_token
if not from_token:
from_token = yield self.event_sources.get_current_token()
from_token = await self.event_sources.get_current_token()
limit = pagination_config.limit
room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
async def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))
......@@ -415,7 +412,7 @@ class Notifier(object):
if only_keys and name not in only_keys:
continue
new_events, new_key = yield source.get_new_events(
new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
limit=limit,
......@@ -425,7 +422,7 @@ class Notifier(object):
)
if name == "room":
new_events = yield filter_events_for_client(
new_events = await filter_events_for_client(
self.storage,
user.to_string(),
new_events,
......@@ -461,7 +458,7 @@ class Notifier(object):
user_id_for_stream,
)
result = yield self.wait_for_events(
result = await self.wait_for_events(
user_id_for_stream,
timeout,
check_for_updates,
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from functools import wraps
......@@ -64,12 +65,22 @@ def measure_func(name=None):
def wrapper(func):
block_name = func.__name__ if name is None else name
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
if inspect.iscoroutinefunction(func):
@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
else:
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
return measured_func
......
......@@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
class SyncTestCase(tests.unittest.TestCase):
class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.sync_handler = SyncHandler(self.hs)
def prepare(self, reactor, clock, hs):
self.hs = hs
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
......
......@@ -18,6 +18,7 @@
import gc
import hashlib
import hmac
import inspect
import logging
import time
......@@ -25,7 +26,7 @@ from mock import Mock
from canonicaljson import json
from twisted.internet.defer import Deferred, succeed
from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
......@@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
......@@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment