Skip to content
Snippets Groups Projects
Commit 1135193d authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Validate group ids when parsing

May as well do it whenever we parse a Group ID. We check the sigil and basic
structure here so it makes sense to check the grammar in the same place.
parent 29812c62
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import logging import logging
from synapse import types
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
...@@ -696,9 +695,11 @@ class GroupsServerHandler(object): ...@@ -696,9 +695,11 @@ class GroupsServerHandler(object):
def create_group(self, group_id, user_id, content): def create_group(self, group_id, user_id, content):
group = yield self.check_group_is_ours(group_id) group = yield self.check_group_is_ours(group_id)
_validate_group_id(group_id)
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
# parsing the id into a GroupID validates it.
group_id_obj = GroupID.from_string(group_id)
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
...@@ -708,7 +709,7 @@ class GroupsServerHandler(object): ...@@ -708,7 +709,7 @@ class GroupsServerHandler(object):
raise SynapseError( raise SynapseError(
403, "Only server admin can create group on this server", 403, "Only server admin can create group on this server",
) )
localpart = GroupID.from_string(group_id).localpart localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix): if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError( raise SynapseError(
400, 400,
...@@ -784,15 +785,3 @@ def _parse_visibility_from_contents(content): ...@@ -784,15 +785,3 @@ def _parse_visibility_from_contents(content):
is_public = True is_public = True
return is_public return is_public
def _validate_group_id(group_id):
"""Validates the group ID is valid for creation on this home server
"""
localpart = GroupID.from_string(group_id).localpart
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
)
...@@ -161,6 +161,23 @@ class GroupID(DomainSpecificString): ...@@ -161,6 +161,23 @@ class GroupID(DomainSpecificString):
"""Structure representing a group ID.""" """Structure representing a group ID."""
SIGIL = "+" SIGIL = "+"
@classmethod
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
raise SynapseError(
400,
"Group ID cannot be empty",
)
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
)
return group_id
mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits) mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
......
...@@ -17,7 +17,7 @@ from tests import unittest ...@@ -17,7 +17,7 @@ from tests import unittest
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, GroupID
mock_homeserver = HomeServer(hostname="my.domain") mock_homeserver = HomeServer(hostname="my.domain")
...@@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase): ...@@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase):
room = RoomAlias("channel", "my.domain") room = RoomAlias("channel", "my.domain")
self.assertEquals(room.to_string(), "#channel:my.domain") self.assertEquals(room.to_string(), "#channel:my.domain")
class GroupIDTestCase(unittest.TestCase):
def test_parse(self):
group_id = GroupID.from_string("+group/=_-.123:my.domain")
self.assertEqual("group/=_-.123", group_id.localpart)
self.assertEqual("my.domain", group_id.domain)
def test_validate(self):
bad_ids = [
"$badsigil:domain",
"+:empty",
] + [
"+group" + c + ":domain" for c in "A%?æ£"
]
for id_string in bad_ids:
try:
GroupID.from_string(id_string)
self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment