Skip to content
Snippets Groups Projects
Commit 059651ef authored by Paul "LeoNerd" Evans's avatar Paul "LeoNerd" Evans
Browse files

Have the Filtering API return Deferreds, so we can do the Datastore implementation nicely

parent b1503112
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
# TODO(paul)
_filters_for_user = {}
......@@ -24,18 +26,28 @@ class Filtering(object):
super(Filtering, self).__init__()
self.hs = hs
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
return filters[filter_id]
# trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return filter_id
# trivial yield, see above
yield
defer.returnValue(filter_id)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
......@@ -54,10 +54,12 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
defer.returnValue((200, self.filtering.get_user_filter(
filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
)))
)
defer.returnValue((200, filter))
except KeyError:
raise SynapseError(400, "No such filter")
......@@ -89,7 +91,7 @@ class CreateFilterRestServlet(RestServlet):
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = self.filtering.add_user_filter(
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
definition=content,
)
......
......@@ -53,14 +53,15 @@ class FilteringTestCase(unittest.TestCase):
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def test_filter(self):
filter_id = self.filtering.add_user_filter(
filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
definition={"type": ["m.*"]},
)
self.assertEquals(filter_id, 0)
filter = self.filtering.get_user_filter(
filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment