Skip to content
Snippets Groups Projects
Commit 777d9914 authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Implement filter algorithm. Add basic event type unit tests to assert it works.

parent 50de1eaa
No related branches found
No related tags found
No related merge requests found
......@@ -91,8 +91,57 @@ class Filtering(object):
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition["rooms"] if "rooms" in definition else None
reject_rooms = (
definition["not_rooms"] if "not_rooms" in definition else None
)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = (
definition["senders"] if "senders" in definition else None
)
reject_senders = (
definition["not_senders"] if "not_senders" in definition else None
)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type
def _check_valid_filter(self, user_filter):
"""Check if the provided filter is valid.
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from tests import unittest
from twisted.internet import defer
......@@ -27,6 +27,7 @@ from synapse.server import HomeServer
user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")
class FilteringTestCase(unittest.TestCase):
......@@ -55,6 +56,48 @@ class FilteringTestCase(unittest.TestCase):
self.datastore = hs.get_datastore()
def test_definition_include_literal_types(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_include_wildcard_types(self):
definition = {
"types": ["m.*", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_exclude_unknown_types(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="now.for.something.completely.different",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
@defer.inlineCallbacks
def test_add_filter(self):
user_filter = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment