Skip to content
Snippets Groups Projects
Commit 3c2d6c70 authored by Erik Johnston's avatar Erik Johnston
Browse files

Add maybe_awaitable and fix __init__ bugs

parent c3b0fbe9
No related branches found
No related tags found
No related merge requests found
...@@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet ...@@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import UserAdminServlet from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet): ...@@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
purge_id = await self.pagination_handler.start_purge_history( purge_id = self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events room_id, token, delete_local_events=delete_local_events
) )
...@@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet): ...@@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet):
ratelimit=False, ratelimit=False,
) )
aliases_for_room = await self.store.get_aliases_for_room(room_id) aliases_for_room = await maybe_awaitable(
self.store.get_aliases_for_room(room_id)
)
await self.store.update_aliases_for_room( await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id room_id, new_room_id, requester_user_id
......
...@@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union ...@@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
import attr
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.python import failure from twisted.python import failure
...@@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): ...@@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb) deferred.addCallbacks(success_cb, failure_cb)
return new_d return new_d
@attr.s(slots=True, frozen=True)
class DoneAwaitable(object):
"""Simple awaitable that returns the provided value.
"""
value = attr.ib()
def __await__(self):
return self
def __iter__(self):
return self
def __next__(self):
raise StopIteration(self.value)
def maybe_awaitable(value):
"""Convert a value to an awaitable if not already an awaitable.
"""
if hasattr(value, "__await__"):
return value
return DoneAwaitable(value)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment