Skip to content
Snippets Groups Projects
Commit d94f682a authored by Erik Johnston's avatar Erik Johnston
Browse files

During room intial sync, only calculate current state once.

parent 76c5a5c2
No related branches found
No related tags found
No related merge requests found
......@@ -89,12 +89,19 @@ class Auth(object):
raise
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id):
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
def check_joined_room(self, room_id, user_id, current_state=None):
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
......@@ -102,7 +109,7 @@ class Auth(object):
def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id)
for event in curr_state:
for event in curr_state.values():
if event.type == EventTypes.Member:
try:
if UserID.from_string(event.state_key).domain != host:
......
......@@ -35,6 +35,7 @@ class MessageHandler(BaseHandler):
def __init__(self, hs):
super(MessageHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
......@@ -225,7 +226,9 @@ class MessageHandler(BaseHandler):
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec()
defer.returnValue([serialize_event(c, now) for c in current_state])
defer.returnValue(
[serialize_event(c, now) for c in current_state.values()]
)
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
......@@ -313,7 +316,7 @@ class MessageHandler(BaseHandler):
)
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state
for c in current_state.values()
]
except:
logger.exception("Failed to get snapshot")
......@@ -329,7 +332,14 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False):
yield self.auth.check_joined_room(room_id, user_id)
current_state = yield self.state.get_current_state(
room_id=room_id,
)
yield self.auth.check_joined_room(
room_id, user_id,
current_state=current_state
)
# TODO(paul): I wish I was called with user objects not user_id
# strings...
......@@ -337,13 +347,12 @@ class MessageHandler(BaseHandler):
# TODO: These concurrently
time_now = self.clock.time_msec()
state_tuples = yield self.state_handler.get_current_state(room_id)
state = [serialize_event(x, time_now) for x in state_tuples]
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
member_event = (yield self.store.get_room_member(
user_id=user_id,
room_id=room_id
))
member_event = current_state.get((EventTypes.Member, user_id,))
now_token = yield self.hs.get_event_sources().get_current_token()
......@@ -360,7 +369,10 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = yield self.store.get_room_members(room_id)
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
]
presence_handler = self.hs.get_handlers().presence_handler
presence = []
......
......@@ -175,9 +175,10 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token,
)
current_state_events = yield self.state_handler.get_current_state(
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
defer.returnValue(RoomSyncResult(
room_id=room_id,
......@@ -347,9 +348,10 @@ class SyncHandler(BaseHandler):
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state_events = yield self.state_handler.get_current_state(
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
......@@ -431,6 +433,7 @@ class SyncHandler(BaseHandler):
joined = True
if joined:
state_delta = yield self.state_handler.get_current_state(room_id)
res = yield self.state_handler.get_current_state(room_id)
state_delta = res.values()
defer.returnValue(state_delta)
......@@ -76,7 +76,7 @@ class StateHandler(object):
defer.returnValue(res[1].get((event_type, state_key)))
return
defer.returnValue(res[1].values())
defer.returnValue(res[1])
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment