Skip to content
Snippets Groups Projects
Commit 2223204e authored by Mark Haines's avatar Mark Haines
Browse files

Hook push rules up to the replication API

parent a1cf9e3b
No related branches found
No related tags found
No related merge requests found
......@@ -36,6 +36,7 @@ STREAM_NAMES = (
("receipts",),
("user_account_data", "room_account_data", "tag_account_data",),
("backfill",),
("push_rules",),
)
......@@ -63,6 +64,7 @@ class ReplicationResource(Resource):
* "room_account_data: Per room per user account data.
* "tag_account_data": Per room per user tags.
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
The API takes two additional query parameters:
......@@ -117,14 +119,16 @@ class ReplicationResource(Resource):
def current_replication_token(self):
stream_token = yield self.sources.get_current_token()
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
defer.returnValue(_ReplicationToken(
stream_token.room_stream_id,
room_stream_token,
int(stream_token.presence_key),
int(stream_token.typing_key),
int(stream_token.receipt_key),
int(stream_token.account_data_key),
backfill_token,
push_rules_token,
))
@request_handler
......@@ -146,6 +150,7 @@ class ReplicationResource(Resource):
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total)
......@@ -277,6 +282,21 @@ class ReplicationResource(Resource):
"position", "user_id", "room_id", "tags"
))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit
)
writer.write_header_and_rows("push_rules", rows, (
"position", "stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
......@@ -307,12 +327,16 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules"
))):
__slots__ = []
def __new__(cls, *args):
if len(args) == 1:
return cls(*(int(value) for value in args[0].split("_")))
streams = [int(value) for value in args[0].split("_")]
if len(streams) < len(cls._fields):
streams.extend([0] * (len(cls._fields) - len(streams)))
return cls(*streams)
else:
return super(_ReplicationToken, cls).__new__(cls, *args)
......
......@@ -412,6 +412,12 @@ class PushRuleStore(SQLBaseStore):
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token()
class RuleNotFoundException(Exception):
pass
......
......@@ -35,7 +35,8 @@ class ReplicationResourceCase(unittest.TestCase):
"send_message",
]),
)
self.user = UserID.from_string("@seeing:red")
self.user_id = "@seeing:red"
self.user = UserID.from_string(self.user_id)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
......@@ -101,7 +102,7 @@ class ReplicationResourceCase(unittest.TestCase):
event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1")
yield self.hs.get_handlers().receipts_handler.received_client_receipt(
room_id, "m.read", self.user.to_string(), event_id
room_id, "m.read", self.user_id, event_id
)
code, body = yield get
self.assertEquals(code, 200)
......@@ -129,6 +130,7 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_room_account_data = _test_timeout("room_account_data")
test_timeout_tag_account_data = _test_timeout("tag_account_data")
test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
@defer.inlineCallbacks
def send_text_message(self, room_id, message):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment