diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..3ce93cb434f324bb92fe2ac6cca80324f41f62a3 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,17 @@ +sudo: false +language: python +python: 2.7 + +# tell travis to cache ~/.cache/pip +cache: pip + +env: + - TOX_ENV=packaging + - TOX_ENV=pep8 + - TOX_ENV=py27 + +install: + - pip install tox + +script: + - tox -e $TOX_ENV diff --git a/CHANGES.rst b/CHANGES.rst index a1a0624674cb8934476d21a1f5cf9a5ba745fb32..aafd61ab4ae32b2e4fbba26aa7a81937c989ffd3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,68 @@ +Changes in synapse v0.18.5 (2016-12-16) +======================================= + +Bug fixes: + +* Fix federation /backfill returning events it shouldn't (PR #1700) +* Fix crash in url preview (PR #1701) + + +Changes in synapse v0.18.5-rc3 (2016-12-13) +=========================================== + +Features: + +* Add support for E2E for guests (PR #1653) +* Add new API appservice specific public room list (PR #1676) +* Add new room membership APIs (PR #1680) + + +Changes: + +* Enable guest access for private rooms by default (PR #653) +* Limit the number of events that can be created on a given room concurrently + (PR #1620) +* Log the args that we have on UI auth completion (PR #1649) +* Stop generating refresh_tokens (PR #1654) +* Stop putting a time caveat on access tokens (PR #1656) +* Remove unspecced GET endpoints for e2e keys (PR #1694) + + +Bug fixes: + +* Fix handling of 500 and 429's over federation (PR #1650) +* Fix Content-Type header parsing (PR #1660) +* Fix error when previewing sites that include unicode, thanks to kyrias (PR + #1664) +* Fix some cases where we drop read receipts (PR #1678) +* Fix bug where calls to ``/sync`` didn't correctly timeout (PR #1683) +* Fix bug where E2E key query would fail if a single remote host failed (PR + #1686) + + + +Changes in synapse v0.18.5-rc2 (2016-11-24) +=========================================== + +Bug fixes: + +* Don't send old events over federation, fixes bug in -rc1. + +Changes in synapse v0.18.5-rc1 (2016-11-24) +=========================================== + +Features: + +* Implement "event_fields" in filters (PR #1638) + +Changes: + +* Use external ldap auth pacakge (PR #1628) +* Split out federation transaction sending to a worker (PR #1635) +* Fail with a coherent error message if `/sync?filter=` is invalid (PR #1636) +* More efficient notif count queries (PR #1644) + + Changes in synapse v0.18.4 (2016-11-22) ======================================= diff --git a/README.rst b/README.rst index 8925e7c49de4ac58bc969b94e2cd3e079d4a915b..5ffcff22cda1db82f63a6ede54eec87dceb2ef16 100644 --- a/README.rst +++ b/README.rst @@ -120,6 +120,7 @@ Installing prerequisites on Mac OS X:: xcode-select --install sudo easy_install pip sudo pip install virtualenv + brew install pkg-config libffi Installing prerequisites on Raspbian:: @@ -136,6 +137,10 @@ Installing prerequisites on openSUSE:: sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ python-devel libffi-devel libopenssl-devel libjpeg62-devel +Installing prerequisites on OpenBSD:: + doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \ + libxslt + To install the synapse homeserver run:: virtualenv -p python2.7 ~/.synapse @@ -369,6 +374,32 @@ Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Mo - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean`` - Packages: ``pkg install py27-matrix-synapse`` + +OpenBSD +------- + +There is currently no port for OpenBSD. Additionally, OpenBSD's security +settings require a slightly more difficult installation process. + +1) Create a new directory in ``/usr/local`` called ``_synapse``. Also, create a + new user called ``_synapse`` and set that directory as the new user's home. + This is required because, by default, OpenBSD only allows binaries which need + write and execute permissions on the same memory space to be run from + ``/usr/local``. +2) ``su`` to the new ``_synapse`` user and change to their home directory. +3) Create a new virtualenv: ``virtualenv -p python2.7 ~/.synapse`` +4) Source the virtualenv configuration located at + ``/usr/local/_synapse/.synapse/bin/activate``. This is done in ``ksh`` by + using the ``.`` command, rather than ``bash``'s ``source``. +5) Optionally, use ``pip`` to install ``lxml``, which Synapse needs to parse + webpages for their titles. +6) Use ``pip`` to install this repository: ``pip install + https://github.com/matrix-org/synapse/tarball/master`` +7) Optionally, change ``_synapse``'s shell to ``/bin/false`` to reduce the + chance of a compromised Synapse server being used to take over your box. + +After this, you may proceed with the rest of the install directions. + NixOS ----- diff --git a/jenkins-dendron-postgres.sh b/jenkins-dendron-postgres.sh index 70edae43285b133d680c42989a893679a1cd36ad..55ff31fd18f93801d379f4a9111c4d9d6c5be1e9 100755 --- a/jenkins-dendron-postgres.sh +++ b/jenkins-dendron-postgres.sh @@ -22,3 +22,4 @@ export SYNAPSE_CACHE_FACTOR=1 --federation-reader \ --client-reader \ --appservice \ + --federation-sender \ diff --git a/jenkins/prepare_synapse.sh b/jenkins/prepare_synapse.sh index 6c26c5842c16c5d1baa1486506b914a030aa9966..ffcb1cfab90b55546c7f7213c30cb526f31847f1 100755 --- a/jenkins/prepare_synapse.sh +++ b/jenkins/prepare_synapse.sh @@ -15,6 +15,6 @@ tox -e py27 --notest -v TOX_BIN=$TOX_DIR/py27/bin $TOX_BIN/pip install setuptools -python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install -$TOX_BIN/pip install lxml -$TOX_BIN/pip install psycopg2 +{ python synapse/python_dependencies.py + echo lxml psycopg2 +} | xargs $TOX_BIN/pip install diff --git a/setup.py b/setup.py index c0716a1599efa96d7d3cfbc32ee4e8ef4d7acb23..b00c2af367a73d447c0398190af9930c7fb675e8 100755 --- a/setup.py +++ b/setup.py @@ -23,6 +23,45 @@ import sys here = os.path.abspath(os.path.dirname(__file__)) +# Some notes on `setup.py test`: +# +# Once upon a time we used to try to make `setup.py test` run `tox` to run the +# tests. That's a bad idea for three reasons: +# +# 1: `setup.py test` is supposed to find out whether the tests work in the +# *current* environmentt, not whatever tox sets up. +# 2: Empirically, trying to install tox during the test run wasn't working ("No +# module named virtualenv"). +# 3: The tox documentation advises against it[1]. +# +# Even further back in time, we used to use setuptools_trial [2]. That has its +# own set of issues: for instance, it requires installation of Twisted to build +# an sdist (because the recommended mode of usage is to add it to +# `setup_requires`). That in turn means that in order to successfully run tox +# you have to have the python header files installed for whichever version of +# python tox uses (which is python3 on recent ubuntus, for example). +# +# So, for now at least, we stick with what appears to be the convention among +# Twisted projects, and don't attempt to do anything when someone runs +# `setup.py test`; instead we direct people to run `trial` directly if they +# care. +# +# [1]: http://tox.readthedocs.io/en/2.5.0/example/basic.html#integration-with-setup-py-test-command +# [2]: https://pypi.python.org/pypi/setuptools_trial +class TestCommand(Command): + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + print ("""Synapse's tests cannot be run via setup.py. To run them, try: + PYTHONPATH="." trial tests +""") + def read_file(path_segments): """Read a file from the package. Takes a list of strings to join to make the path""" @@ -39,38 +78,6 @@ def exec_file(path_segments): return result -class Tox(Command): - user_options = [('tox-args=', 'a', "Arguments to pass to tox")] - - def initialize_options(self): - self.tox_args = None - - def finalize_options(self): - self.test_args = [] - self.test_suite = True - - def run(self): - #import here, cause outside the eggs aren't loaded - try: - import tox - except ImportError: - try: - self.distribution.fetch_build_eggs("tox") - import tox - except: - raise RuntimeError( - "The tests need 'tox' to run. Please install 'tox'." - ) - import shlex - args = self.tox_args - if args: - args = shlex.split(self.tox_args) - else: - args = [] - errno = tox.cmdline(args=args) - sys.exit(errno) - - version = exec_file(("synapse", "__init__.py"))["__version__"] dependencies = exec_file(("synapse", "python_dependencies.py")) long_description = read_file(("README.rst",)) @@ -86,5 +93,5 @@ setup( zip_safe=False, long_description=long_description, scripts=["synctl"] + glob.glob("scripts/*"), - cmdclass={'test': Tox}, + cmdclass={'test': TestCommand}, ) diff --git a/synapse/__init__.py b/synapse/__init__.py index 432567a110b9ae2a117b8caa0325927607f12ce2..f006e10dc5306ad72e683246bc0033a1aff3ee27 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.18.4" +__version__ = "0.18.5" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 69b3392735e100eb818253d33f9c2a7cea7df679..a99986714d55829a707ed0a4be0a6e729964d25d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -39,6 +39,9 @@ AuthEventTypes = ( EventTypes.ThirdPartyInvite, ) +# guests always get this device id. +GUEST_DEVICE_ID = "guest_device" + class Auth(object): """ @@ -51,17 +54,6 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 - # Docs for these currently lives at - # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst - # In addition, we have type == delete_pusher which grants access only to - # delete pushers. - self._KNOWN_CAVEAT_PREFIXES = set([ - "gen = ", - "guest = ", - "type = ", - "time < ", - "user_id = ", - ]) @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): @@ -685,31 +677,28 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): - """ Get a registered user's ID. + """ Validate access token and get user_id from it Args: token (str): The access token to get the user by. + rights (str): The operation being performed; the access token must + allow this. Returns: dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self.get_user_from_macaroon(token, rights) - except AuthError: - # TODO(daniel): Remove this fallback when all existing access tokens - # have been re-issued as macaroons. - if self.hs.config.expire_access_token: - raise - ret = yield self._look_up_user_by_access_token(token) - - defer.returnValue(ret) + macaroon = pymacaroons.Macaroon.deserialize(token) + except Exception: # deserialize can throw more-or-less anything + # doesn't look like a macaroon: treat it as an opaque token which + # must be in the database. + # TODO: it would be nice to get rid of this, but apparently some + # people use access tokens which aren't macaroons + r = yield self._look_up_user_by_access_token(token) + defer.returnValue(r) - @defer.inlineCallbacks - def get_user_from_macaroon(self, macaroon_str, rights="access"): try: - macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_id = self.get_user_id_from_macaroon(macaroon) user = UserID.from_string(user_id) @@ -724,11 +713,36 @@ class Auth(object): guest = True if guest: + # Guest access tokens are not stored in the database (there can + # only be one access token per guest, anyway). + # + # In order to prevent guest access tokens being used as regular + # user access tokens (and hence getting around the invalidation + # process), we look up the user id and check that it is indeed + # a guest user. + # + # It would of course be much easier to store guest access + # tokens in the database as well, but that would break existing + # guest tokens. + stored_user = yield self.store.get_user_by_id(user_id) + if not stored_user: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Unknown user_id %s" % user_id, + errcode=Codes.UNKNOWN_TOKEN + ) + if not stored_user["is_guest"]: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Guest access token used for regular user", + errcode=Codes.UNKNOWN_TOKEN + ) ret = { "user": user, "is_guest": True, "token_id": None, - "device_id": None, + # all guests get the same device id + "device_id": GUEST_DEVICE_ID, } elif rights == "delete_pusher": # We don't store these tokens in the database @@ -750,7 +764,7 @@ class Auth(object): # macaroon. They probably should be. # TODO: build the dictionary from the macaroon once the # above are fixed - ret = yield self._look_up_user_by_access_token(macaroon_str) + ret = yield self._look_up_user_by_access_token(token) if ret["user"] != user: logger.error( "Macaroon user (%s) != DB user (%s)", @@ -798,27 +812,38 @@ class Auth(object): Args: macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token required (e.g. "access", "refresh", + type_string(str): The kind of token required (e.g. "access", "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. - This should really always be True, but no clients currently implement - token refresh, so we can't enforce expiry yet. user_id (str): The user_id required """ v = pymacaroons.Verifier() + + # the verifier runs a test for every caveat on the macaroon, to check + # that it is met for the current request. Each caveat must match at + # least one of the predicates specified by satisfy_exact or + # specify_general. v.satisfy_exact("gen = 1") v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") + + # verify_expiry should really always be True, but there exist access + # tokens in the wild which expire when they should not, so we can't + # enforce expiry yet (so we have to allow any caveat starting with + # 'time < ' in access tokens). + # + # On the other hand, short-term login tokens (as used by CAS login, for + # example) have an expiry time which we do want to enforce. + if verify_expiry: v.satisfy_general(self._verify_expiry) else: v.satisfy_general(lambda c: c.startswith("time < ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + # access_tokens include a nonce for uniqueness: any value is acceptable + v.satisfy_general(lambda c: c.startswith("nonce = ")) - v = pymacaroons.Verifier() - v.satisfy_general(self._verify_recognizes_caveats) v.verify(macaroon, self.hs.config.macaroon_secret_key) def _verify_expiry(self, caveat): @@ -829,15 +854,6 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - def _verify_recognizes_caveats(self, caveat): - first_space = caveat.find(" ") - if first_space < 0: - return False - second_space = caveat.find(" ", first_space + 1) - if second_space < 0: - return False - return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES - @defer.inlineCallbacks def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 00416468584c11f0db514fe605fca1ec74b4b8e0..921c4577384fdd9ace18711676b894c1cebe2d0b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -39,6 +39,7 @@ class Codes(object): CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" MISSING_PARAM = "M_MISSING_PARAM" + INVALID_PARAM = "M_INVALID_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3b3ef70750ebf9bacba3c11d1d586333d2c304f6..fb291d7fb9f6d8367430979d95d1cee977b6ebe4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -71,6 +71,21 @@ class Filtering(object): if key in user_filter_json["room"]: self._check_definition(user_filter_json["room"][key]) + if "event_fields" in user_filter_json: + if type(user_filter_json["event_fields"]) != list: + raise SynapseError(400, "event_fields must be a list of strings") + for field in user_filter_json["event_fields"]: + if not isinstance(field, basestring): + raise SynapseError(400, "Event field must be a string") + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + if r'\\' in field: + raise SynapseError( + 400, r'The escape character \ cannot itself be escaped' + ) + def _check_definition_room_lists(self, definition): """Check that "rooms" and "not_rooms" are lists of room ids if they are present @@ -152,6 +167,7 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self.event_fields = filter_json.get("event_fields", []) def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) @@ -186,6 +202,26 @@ class FilterCollection(object): def filter_room_account_data(self, events): return self._room_account_data.filter(self._room_filter.filter(events)) + def blocks_all_presence(self): + return ( + self._presence_filter.filters_all_types() or + self._presence_filter.filters_all_senders() + ) + + def blocks_all_room_ephemeral(self): + return ( + self._room_ephemeral_filter.filters_all_types() or + self._room_ephemeral_filter.filters_all_senders() or + self._room_ephemeral_filter.filters_all_rooms() + ) + + def blocks_all_room_timeline(self): + return ( + self._room_timeline_filter.filters_all_types() or + self._room_timeline_filter.filters_all_senders() or + self._room_timeline_filter.filters_all_rooms() + ) + class Filter(object): def __init__(self, filter_json): @@ -202,6 +238,15 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + def filters_all_types(self): + return "*" in self.not_types + + def filters_all_senders(self): + return "*" in self.not_senders + + def filters_all_rooms(self): + return "*" in self.not_rooms + def check(self, event): """Checks whether the filter matches the given event. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py new file mode 100644 index 0000000000000000000000000000000000000000..80ea4c8062fd2be9d238253447123809e2ce302a --- /dev/null +++ b/synapse/app/federation_sender.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.server import HomeServer +from synapse.config._base import ConfigError +from synapse.config.logger import setup_logging +from synapse.config.homeserver import HomeServerConfig +from synapse.crypto import context_factory +from synapse.http.site import SynapseSite +from synapse.federation import send_queue +from synapse.federation.units import Edu +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.storage.engines import create_engine +from synapse.storage.presence import UserPresenceState +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string + +from synapse import events + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc +import ujson as json + +logger = logging.getLogger("synapse.app.appservice") + + +class FederationSenderSlaveStore( + SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, + SlavedRegistrationStore, +): + pass + + +class FederationSenderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse federation_sender now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + send_handler = FederationSenderHandler(self) + + send_handler.on_start() + + while True: + try: + args = store.stream_positions() + args.update((yield send_handler.stream_positions())) + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + yield send_handler.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(30) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse federation sender", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.federation_sender" + + setup_logging(config.worker_log_config, config.worker_log_file) + + events.USE_FROZEN_DICTS = config.use_frozen_dicts + + database_engine = create_engine(config.database_config) + + if config.send_federation: + sys.stderr.write( + "\nThe send_federation must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``send_federation: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.send_federation = True + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ps = FederationSenderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ps.setup() + ps.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ps.replicate() + ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-federation-sender", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +class FederationSenderHandler(object): + """Processes the replication stream and forwards the appropriate entries + to the federation sender. + """ + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation_sender = hs.get_federation_sender() + + self._room_serials = {} + self._room_typing = {} + + def on_start(self): + # There may be some events that are persisted but haven't been sent, + # so send them now. + self.federation_sender.notify_new_events( + self.store.get_room_max_stream_ordering() + ) + + @defer.inlineCallbacks + def stream_positions(self): + stream_id = yield self.store.get_federation_out_pos("federation") + defer.returnValue({ + "federation": stream_id, + + # Ack stuff we've "processed", this should only be called from + # one process. + "federation_ack": stream_id, + }) + + @defer.inlineCallbacks + def process_replication(self, result): + # The federation stream contains things that we want to send out, e.g. + # presence, typing, etc. + fed_stream = result.get("federation") + if fed_stream: + latest_id = int(fed_stream["position"]) + + # The federation stream containis a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + presence_to_send = {} + keyed_edus = {} + edus = {} + failures = {} + device_destinations = set() + + # Parse the rows in the stream + for row in fed_stream["rows"]: + position, typ, content_js = row + content = json.loads(content_js) + + if typ == send_queue.PRESENCE_TYPE: + destination = content["destination"] + state = UserPresenceState.from_dict(content["state"]) + + presence_to_send.setdefault(destination, []).append(state) + elif typ == send_queue.KEYED_EDU_TYPE: + key = content["key"] + edu = Edu(**content["edu"]) + + keyed_edus.setdefault( + edu.destination, {} + )[(edu.destination, tuple(key))] = edu + elif typ == send_queue.EDU_TYPE: + edu = Edu(**content) + + edus.setdefault(edu.destination, []).append(edu) + elif typ == send_queue.FAILURE_TYPE: + destination = content["destination"] + failure = content["failure"] + + failures.setdefault(destination, []).append(failure) + elif typ == send_queue.DEVICE_MESSAGE_TYPE: + device_destinations.add(content["destination"]) + else: + raise Exception("Unrecognised federation type: %r", typ) + + # We've finished collecting, send everything off + for destination, states in presence_to_send.items(): + self.federation_sender.send_presence(destination, states) + + for destination, edu_map in keyed_edus.items(): + for key, edu in edu_map.items(): + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in edus.items(): + for edu in edu_list: + self.federation_sender.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in failures.items(): + for failure in failure_list: + self.federation_sender.send_failure(destination, failure) + + for destination in device_destinations: + self.federation_sender.send_device_messages(destination) + + # Record where we are in the stream. + yield self.store.update_federation_out_pos( + "federation", latest_id + ) + + # We also need to poke the federation sender when new events happen + event_stream = result.get("events") + if event_stream: + latest_pos = event_stream["position"] + self.federation_sender.notify_new_events(latest_pos) + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 91471f7e890101205112a0d41063486f4314c048..b0106a35977eebed2632f156515f9076505fd5dc 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -89,6 +89,9 @@ class ApplicationService(object): self.namespaces = self._check_namespaces(namespaces) self.id = id + if "|" in self.id: + raise Exception("application service ID cannot contain '|' character") + # .protocols is a publicly visible field if protocols: self.protocols = set(protocols) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index b0eb0c6d9dcf295070e2c22082723ff52a3c102a..6893610e715b9198f4a79ba7a7bb55f3ff6ec1e0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event from synapse.util.caches.response_cache import ResponseCache +from synapse.types import ThirdPartyInstanceID import logging import urllib @@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient): " valid result", uri) defer.returnValue(None) + for instance in info.get("instances", []): + network_id = instance.get("network_id", None) + if network_id is not None: + instance["instance_id"] = ThirdPartyInstanceID( + service.id, network_id, + ).to_string() + defer.returnValue(info) except Exception as ex: logger.warning("query_3pe_protocol to %s threw exception %s", diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dc68683fbc2edab363054630438cf217d89123d1..ec72c95436a3277bce8e8059b6667740bd3e5596 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -50,6 +50,7 @@ handlers: console: class: logging.StreamHandler formatter: precise + filters: [context] loggers: synapse: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 1f438d2bb3740fc70cf47603a9adff90bf3b806f..83762d089a0fd8b362b7c586c77d82cde893908d 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -27,17 +27,23 @@ class PasswordAuthProviderConfig(Config): ldap_config = config.get("ldap_config", {}) self.ldap_enabled = ldap_config.get("enabled", False) if self.ldap_enabled: - from synapse.util.ldap_auth_provider import LdapAuthProvider + from ldap_auth_provider import LdapAuthProvider parsed_config = LdapAuthProvider.parse_config(ldap_config) self.password_providers.append((LdapAuthProvider, parsed_config)) providers = config.get("password_providers", []) for provider in providers: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + # This is for backwards compat when the ldap auth provider resided + # in this package. + if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": + from ldap_auth_provider import LdapAuthProvider + provider_class = LdapAuthProvider + else: + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) @@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config): def default_config(self, **kwargs): return """\ # password_providers: - # - module: "synapse.util.ldap_auth_provider.LdapAuthProvider" + # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true # uri: "ldap://ldap.example.com:389" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc3f879857a8df2e6fdb10d0b662f120a65370d5..87e500c97a21f19805b190c1702efba035eb3091 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -32,7 +32,6 @@ class RegistrationConfig(Config): ) self.registration_shared_secret = config.get("registration_shared_secret") - self.user_creation_max_duration = int(config["user_creation_max_duration"]) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] @@ -55,11 +54,6 @@ class RegistrationConfig(Config): # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" - # Sets the expiry for the short term user creation in - # milliseconds. For instance the bellow duration is two weeks - # in milliseconds. - user_creation_max_duration: 1209600000 - # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. diff --git a/synapse/config/server.py b/synapse/config/server.py index ed5417d0c35a17beed737628da92ddfe18e8b116..634d8e6fe5d3ce73ce8d1aae9ed4a389fb040930 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -30,6 +30,11 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 0e9fd902af69d8d1ddf3d932c5540b78b3481bf9..5bbaef818708b96fe6c016fc9f44319be58e5bd8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,6 +16,17 @@ from synapse.api.constants import EventTypes from . import EventBase +from frozendict import frozendict + +import re + +# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' +# (?<!stuff) matches if the current position in the string is not preceded +# by a match for 'stuff'. +# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as +# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" +SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.') + def prune_event(event): """ Returns a pruned version of the given event, which removes all keys we @@ -97,6 +108,83 @@ def prune_event(event): ) +def _copy_field(src, dst, field): + """Copy the field in 'src' to 'dst'. + + For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] + then dst={"foo":{"bar":5}}. + + Args: + src(dict): The dict to read from. + dst(dict): The dict to modify. + field(list<str>): List of keys to drill down to in 'src'. + """ + if len(field) == 0: # this should be impossible + return + if len(field) == 1: # common case e.g. 'origin_server_ts' + if field[0] in src: + dst[field[0]] = src[field[0]] + return + + # Else is a nested field e.g. 'content.body' + # Pop the last field as that's the key to move across and we need the + # parent dict in order to access the data. Drill down to the right dict. + key_to_move = field.pop(-1) + sub_dict = src + for sub_field in field: # e.g. sub_field => "content" + if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + sub_dict = sub_dict[sub_field] + else: + return + + if key_to_move not in sub_dict: + return + + # Insert the key into the output dictionary, creating nested objects + # as required. We couldn't do this any earlier or else we'd need to delete + # the empty objects if the key didn't exist. + sub_out_dict = dst + for sub_field in field: + sub_out_dict = sub_out_dict.setdefault(sub_field, {}) + sub_out_dict[key_to_move] = sub_dict[key_to_move] + + +def only_fields(dictionary, fields): + """Return a new dict with only the fields in 'dictionary' which are present + in 'fields'. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + dictionary(dict): The dictionary to read from. + fields(list<str>): A list of fields to copy over. Only shallow refs are + taken. + Returns: + dict: A new dictionary with only the given fields. If fields was empty, + the same dictionary is returned. + """ + if len(fields) == 0: + return dictionary + + # for each field, convert it: + # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] + split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] + + # for each element of the output array of arrays: + # remove escaping so we can use the right key names. + split_fields[:] = [ + [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + ] + + output = {} + for field_array in split_fields: + _copy_field(dictionary, output, field_array) + return output + + def format_event_raw(d): return d @@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None): + token_id=None, only_event_fields=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, d["unsigned"]["transaction_id"] = txn_id if as_client_event: - return event_format(d) - else: - return d + d = event_format(d) + + if only_event_fields: + if (not isinstance(only_event_fields, list) or + not all(isinstance(f, basestring) for f in only_event_fields)): + raise TypeError("only_event_fields must be a list of strings") + d = only_fields(d, only_event_fields) + + return d diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 979fdf2431fbe7c57ab3cb86ff6792cc86fdcc7c..2e32d245ba516112246acabdc9cc4b30d9ef9061 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,10 +17,9 @@ """ from .replication import ReplicationLayer -from .transport.client import TransportLayerClient -def initialize_http_replication(homeserver): - transport = TransportLayerClient(homeserver) +def initialize_http_replication(hs): + transport = hs.get_federation_transport_client() - return ReplicationLayer(homeserver, transport) + return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 94e76b1978a688f3d62c4c7ccc8d0177b57d707e..6e23c207ee6179083fa910ea2996d4db7cfea0b9 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,7 +18,6 @@ from twisted.internet import defer from .federation_base import FederationBase from synapse.api.constants import Membership -from .units import Edu from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, @@ -45,10 +44,6 @@ logger = logging.getLogger(__name__) # synapse.federation.federation_client is a silly name metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations") - -sent_edus_counter = metrics.register_counter("sent_edus") - sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) @@ -92,63 +87,6 @@ class FederationClient(FederationBase): self._get_pdu_cache.start() - @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - sent_pdus_destination_dist.inc_by(len(destinations)) - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - def send_presence(self, destination, states): - if destination != self.server_name: - self._transaction_queue.enqueue_presence(destination, states) - - @log_function - def send_edu(self, destination, edu_type, content, key=None): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - sent_edus_counter.inc() - - self._transaction_queue.enqueue_edu(edu, key=key) - - @log_function - def send_device_messages(self, destination): - """Sends the device messages in the local database to the remote - destination""" - self._transaction_queue.enqueue_device_messages(destination) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False): @@ -717,12 +655,15 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") def get_public_rooms(self, destination, limit=None, since_token=None, - search_filter=None): + search_filter=None, include_all_networks=False, + third_party_instance_id=None): if destination == self.server_name: return return self.transport_layer.get_public_rooms( - destination, limit, since_token, search_filter + destination, limit, since_token, search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, ) @defer.inlineCallbacks diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index ea66a5dcbc500277b772b232c5a83b393ca4fdc5..62d865ec4bee22dc960109400fa31cafc5c1e31a 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -20,8 +20,6 @@ a given transport. from .federation_client import FederationClient from .federation_server import FederationServer -from .transaction_queue import TransactionQueue - from .persistence import TransactionActions import logging @@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer): self._clock = hs.get_clock() self.transaction_actions = TransactionActions(self.store) - self._transaction_queue = TransactionQueue(hs, transport_layer) - - self._order = 0 self.hs = hs diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9f7a86f0aab1a22a0af5785ba2738672bc4fd2 --- /dev/null +++ b/synapse/federation/send_queue.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A federation sender that forwards things to be sent across replication to +a worker process. + +It assumes there is a single worker process feeding off of it. + +Each row in the replication stream consists of a type and some json, where the +types indicate whether they are presence, or edus, etc. + +Ephemeral or non-event data are queued up in-memory. When the worker requests +updates since a particular point, all in-memory data since before that point is +dropped. We also expire things in the queue after 5 minutes, to ensure that a +dead worker doesn't cause the queues to grow limitlessly. + +Events are replicated via a separate events stream. +""" + +from .units import Edu + +from synapse.util.metrics import Measure +import synapse.metrics + +from blist import sorteddict +import ujson + + +metrics = synapse.metrics.get_metrics_for(__name__) + + +PRESENCE_TYPE = "p" +KEYED_EDU_TYPE = "k" +EDU_TYPE = "e" +FAILURE_TYPE = "f" +DEVICE_MESSAGE_TYPE = "d" + + +class FederationRemoteSendQueue(object): + """A drop in replacement for TransactionQueue""" + + def __init__(self, hs): + self.server_name = hs.hostname + self.clock = hs.get_clock() + + self.presence_map = {} + self.presence_changed = sorteddict() + + self.keyed_edu = {} + self.keyed_edu_changed = sorteddict() + + self.edus = sorteddict() + + self.failures = sorteddict() + + self.device_messages = sorteddict() + + self.pos = 1 + self.pos_time = sorteddict() + + # EVERYTHING IS SAD. In particular, python only makes new scopes when + # we make a new function, so we need to make a new function so the inner + # lambda binds to the queue rather than to the name of the queue which + # changes. ARGH. + def register(name, queue): + metrics.register_callback( + queue_name + "_size", + lambda: len(queue), + ) + + for queue_name in [ + "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", + "edus", "failures", "device_messages", "pos_time", + ]: + register(queue_name, getattr(self, queue_name)) + + self.clock.looping_call(self._clear_queue, 30 * 1000) + + def _next_pos(self): + pos = self.pos + self.pos += 1 + self.pos_time[self.clock.time_msec()] = pos + return pos + + def _clear_queue(self): + """Clear the queues for anything older than N minutes""" + + FIVE_MINUTES_AGO = 5 * 60 * 1000 + now = self.clock.time_msec() + + keys = self.pos_time.keys() + time = keys.bisect_left(now - FIVE_MINUTES_AGO) + if not keys[:time]: + return + + position_to_delete = max(keys[:time]) + for key in keys[:time]: + del self.pos_time[key] + + self._clear_queue_before_pos(position_to_delete) + + def _clear_queue_before_pos(self, position_to_delete): + """Clear all the queues from before a given position""" + with Measure(self.clock, "send_queue._clear"): + # Delete things out of presence maps + keys = self.presence_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_changed[key] + + user_ids = set( + user_id for uids in self.presence_changed.values() for _, user_id in uids + ) + + to_del = [ + user_id for user_id in self.presence_map if user_id not in user_ids + ] + for user_id in to_del: + del self.presence_map[user_id] + + # Delete things out of keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.keyed_edu_changed[key] + + live_keys = set() + for edu_key in self.keyed_edu_changed.values(): + live_keys.add(edu_key) + + to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] + for edu_key in to_del: + del self.keyed_edu[edu_key] + + # Delete things out of edu map + keys = self.edus.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.edus[key] + + # Delete things out of failure map + keys = self.failures.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.failures[key] + + # Delete things out of device map + keys = self.device_messages.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.device_messages[key] + + def notify_new_events(self, current_id): + """As per TransactionQueue""" + # We don't need to replicate this as it gets sent down a different + # stream. + pass + + def send_edu(self, destination, edu_type, content, key=None): + """As per TransactionQueue""" + pos = self._next_pos() + + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + if key: + assert isinstance(key, tuple) + self.keyed_edu[(destination, key)] = edu + self.keyed_edu_changed[pos] = (destination, key) + else: + self.edus[pos] = edu + + def send_presence(self, destination, states): + """As per TransactionQueue""" + pos = self._next_pos() + + self.presence_map.update({ + state.user_id: state + for state in states + }) + + self.presence_changed[pos] = [ + (destination, state.user_id) for state in states + ] + + def send_failure(self, failure, destination): + """As per TransactionQueue""" + pos = self._next_pos() + + self.failures[pos] = (destination, str(failure)) + + def send_device_messages(self, destination): + """As per TransactionQueue""" + pos = self._next_pos() + self.device_messages[pos] = destination + + def get_current_token(self): + return self.pos - 1 + + def get_replication_rows(self, token, limit, federation_ack=None): + """ + Args: + token (int) + limit (int) + federation_ack (int): Optional. The position where the worker is + explicitly acknowledged it has handled. Allows us to drop + data from before that point + """ + # TODO: Handle limit. + + # To handle restarts where we wrap around + if token > self.pos: + token = -1 + + rows = [] + + # There should be only one reader, so lets delete everything its + # acknowledged its seen. + if federation_ack: + self._clear_queue_before_pos(federation_ack) + + # Fetch changed presence + keys = self.presence_changed.keys() + i = keys.bisect_right(token) + dest_user_ids = set( + (pos, dest_user_id) + for pos in keys[i:] + for dest_user_id in self.presence_changed[pos] + ) + + for (key, (dest, user_id)) in dest_user_ids: + rows.append((key, PRESENCE_TYPE, ujson.dumps({ + "destination": dest, + "state": self.presence_map[user_id].as_dict(), + }))) + + # Fetch changes keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_right(token) + keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + + for (pos, (destination, edu_key)) in keyed_edus: + rows.append( + (pos, KEYED_EDU_TYPE, ujson.dumps({ + "key": edu_key, + "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), + })) + ) + + # Fetch changed edus + keys = self.edus.keys() + i = keys.bisect_right(token) + edus = set((k, self.edus[k]) for k in keys[i:]) + + for (pos, edu) in edus: + rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + + # Fetch changed failures + keys = self.failures.keys() + i = keys.bisect_right(token) + failures = set((k, self.failures[k]) for k in keys[i:]) + + for (pos, (destination, failure)) in failures: + rows.append((pos, FAILURE_TYPE, ujson.dumps({ + "destination": destination, + "failure": failure, + }))) + + # Fetch changed device messages + keys = self.device_messages.keys() + i = keys.bisect_right(token) + device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + + for (pos, destination) in device_messages: + rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ + "destination": destination, + }))) + + # Sort rows based on pos + rows.sort() + + return rows diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f8ca93e4c37e8d1b9754fd609e29e3c9d72113bd..51b656d74a19bc0b8be63f25708a5bf24186306f 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -19,6 +19,7 @@ from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn @@ -26,6 +27,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state import synapse.metrics @@ -36,6 +38,12 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") +sent_pdus_destination_dist = client_metrics.register_distribution( + "sent_pdu_destinations" +) +sent_edus_counter = client_metrics.register_counter("sent_edus") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -44,13 +52,14 @@ class TransactionQueue(object): It batches pending PDUs into single transactions. """ - def __init__(self, hs, transport_layer): + def __init__(self, hs): self.server_name = hs.hostname self.store = hs.get_datastore() + self.state = hs.get_state_handler() self.transaction_actions = TransactionActions(self.store) - self.transport_layer = transport_layer + self.transport_layer = hs.get_federation_transport_client() self.clock = hs.get_clock() @@ -95,6 +104,11 @@ class TransactionQueue(object): # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) + self._order = 1 + + self._is_processing = False + self._last_poked_id = -1 + def can_send_to(self, destination): """Can we send messages to the given server? @@ -115,11 +129,61 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - def enqueue_pdu(self, pdu, destinations, order): + @defer.inlineCallbacks + def notify_new_events(self, current_id): + """This gets called when we have some new events we might want to + send out to other servers. + """ + self._last_poked_id = max(current_id, self._last_poked_id) + + if self._is_processing: + return + + try: + self._is_processing = True + while True: + last_token = yield self.store.get_federation_out_pos("events") + next_token, events = yield self.store.get_all_new_events_stream( + last_token, self._last_poked_id, limit=20, + ) + + logger.debug("Handling %s -> %s", last_token, next_token) + + if not events and next_token >= self._last_poked_id: + break + + for event in events: + users_in_room = yield self.state.get_current_user_in_room( + event.room_id, latest_event_ids=[event.event_id], + ) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(event.state_key)) + + logger.debug("Sending %s to %r", event, destinations) + + self._send_pdu(event, destinations) + + yield self.store.update_federation_out_pos( + "events", next_token + ) + + finally: + self._is_processing = False + + def _send_pdu(self, pdu, destinations): # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. + order = self._order + self._order += 1 + destinations = set(destinations) destinations = set( dest for dest in destinations if self.can_send_to(dest) @@ -130,6 +194,8 @@ class TransactionQueue(object): if not destinations: return + sent_pdus_destination_dist.inc_by(len(destinations)) + for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( (pdu, order) @@ -139,7 +205,10 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_presence(self, destination, states): + def send_presence(self, destination, states): + if not self.can_send_to(destination): + return + self.pending_presence_by_dest.setdefault(destination, {}).update({ state.user_id: state for state in states }) @@ -148,12 +217,19 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu, key=None): - destination = edu.destination + def send_edu(self, destination, edu_type, content, key=None): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) if not self.can_send_to(destination): return + sent_edus_counter.inc() + if key: self.pending_edus_keyed_by_dest.setdefault( destination, {} @@ -165,7 +241,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_failure(self, failure, destination): + def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return @@ -180,7 +256,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_device_messages(self, destination): + def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -191,6 +267,9 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) + def get_current_token(self): + return 0 + @defer.inlineCallbacks def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) @@ -383,6 +462,13 @@ class TransactionQueue(object): code = e.code response = e.response + if e.code == 429 or 500 <= e.code: + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + raise e + logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index db45c7826ceb677d23c331e48b10a8a481d2c305..491cdc29e172b4644d7fe00ae32b2e0b3b968529 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -249,10 +249,15 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_public_rooms(self, remote_server, limit, since_token, - search_filter=None): + search_filter=None, include_all_networks=False, + third_party_instance_id=None): path = PREFIX + "/publicRooms" - args = {} + args = { + "include_all_networks": "true" if include_all_networks else "false", + } + if third_party_instance_id: + args["third_party_instance_id"] = third_party_instance_id, if limit: args["limit"] = [str(limit)] if since_token: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fec337be645409ef857c4f58ead8e31b1e222bc8..159dbd174718e57e871667ec6d813634138fd5d2 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, + parse_boolean_from_args, ) from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string +from synapse.types import ThirdPartyInstanceID import functools import logging @@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet): def on_GET(self, origin, content, query): limit = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) + include_all_networks = parse_boolean_from_args( + query, "include_all_networks", False + ) + third_party_instance_id = parse_string_from_args( + query, "third_party_instance_id", None + ) + + if include_all_networks: + network_tuple = None + elif third_party_instance_id: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + else: + network_tuple = ThirdPartyInstanceID(None, None) + data = yield self.room_list_handler.get_local_public_room_list( - limit, since_token + limit, since_token, + network_tuple=network_tuple ) defer.returnValue((200, data)) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 63d05f25310bd55dbe7f5f8d453df2e4eafe3f24..5ad408f5494b88f7aa24b87fe923bdf0ab5f3f55 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler from .identity import IdentityHandler -from .receipts import ReceiptsHandler from .search import SearchHandler @@ -56,7 +55,6 @@ class Handlers(object): self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) - self.receipts_handler = ReceiptsHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3851b3588901f9f8c222efd33d6afded57ea35f8..3b146f09d69d82d37672362211b6645b6d919181 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,8 @@ class AuthHandler(BaseHandler): for module, config in hs.config.password_providers ] + logger.info("Extra password_providers: %r", self.password_providers) + self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() @@ -160,7 +162,15 @@ class AuthHandler(BaseHandler): for f in flows: if len(set(f) - set(creds.keys())) == 0: - logger.info("Auth completed with creds: %r", creds) + # it's very useful to know what args are stored, but this can + # include the password in the case of registering, so only log + # the keys (confusingly, clientdict may contain a password + # param, creds is just what the user authed as for UI auth + # and is not sensitive). + logger.info( + "Auth completed with creds: %r. Client dict has keys: %r", + creds, clientdict.keys() + ) defer.returnValue((True, creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) @@ -378,12 +388,10 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ - Gets login tuple for the user with the given user ID. - - Creates a new access/refresh token for the user. + Creates a new access token for the user with the given user ID. The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. @@ -398,16 +406,13 @@ class AuthHandler(BaseHandler): initial_display_name (str): display name to associate with the device if it needs re-registering Returns: - A tuple of: The access token for the user's session. - The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - refresh_token = yield self.issue_refresh_token(user_id, device_id) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -418,7 +423,7 @@ class AuthHandler(BaseHandler): user_id, device_id, initial_display_name ) - defer.returnValue((access_token, refresh_token)) + defer.returnValue(access_token) @defer.inlineCallbacks def check_user_exists(self, user_id): @@ -529,35 +534,19 @@ class AuthHandler(BaseHandler): device_id) defer.returnValue(access_token) - @defer.inlineCallbacks - def issue_refresh_token(self, user_id, device_id=None): - refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token, - device_id) - defer.returnValue(refresh_token) - - def generate_access_token(self, user_id, extra_caveats=None, - duration_in_ms=(60 * 60 * 1000)): + def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_refresh_token(self, user_id): - m = self._generate_base_macaroon(user_id) - m.add_first_party_caveat("type = refresh") - # Important to add a nonce, because otherwise every refresh token for a - # user will be the same. - m.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - return m.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index c5368e5df2660a7ead7f0b94e333cb3a015149f1..f7fad15c620f7dfe7cd3ef5776c35594cf56ee75 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -34,9 +34,9 @@ class DeviceMessageHandler(object): self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler( + hs.get_replication_layer().register_edu_handler( "m.direct_to_device", self.on_direct_to_device_edu ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index c00274afc3c9f4f1be98a5aa0051056a05914c73..1b5317edf5097cbac7e17b626ccebc1856f0aa94 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler): yield self.auth.check_can_change_room_list(room_id, requester.user) yield self.store.set_room_is_public(room_id, visibility == "public") + + @defer.inlineCallbacks + def edit_published_appservice_room_list(self, appservice_id, network_id, + room_id, visibility): + """Add or remove a room from the appservice/network specific public + room list. + + Args: + appservice_id (str): ID of the appservice that owns the list + network_id (str): The ID of the network the list is associated with + room_id (str) + visibility (str): either "public" or "private" + """ + if visibility not in ["public", "private"]: + raise SynapseError(400, "Invalid visibility setting") + + yield self.store.set_room_is_public_appservice( + room_id, appservice_id, network_id, visibility == "public" + ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fd11935b401ce4309b642b16200b2bfeba993afc..b63a660c065a9a4301922e45d4347635fd6d91ab 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -111,6 +111,11 @@ class E2eKeysHandler(object): failures[destination] = { "status": 503, "message": "Not ready for retry", } + except Exception as e: + # include ConnectionRefused and other errors + failures[destination] = { + "status": 503, "message": e.message + } yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) @@ -222,6 +227,11 @@ class E2eKeysHandler(object): failures[destination] = { "status": 503, "message": "Not ready for retry", } + except Exception as e: + # include ConnectionRefused and other errors + failures[destination] = { + "status": 503, "message": e.message + } yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(claim_client_keys)(destination) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2d801bad476e280423cfded3aa8ce1db4ef2aebd..1d07e4d02b22183c2f90babbb9629908633945c8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -80,22 +80,6 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up self.room_queues = {} - def handle_new_event(self, event, destinations): - """ Takes in an event from the client to server side, that has already - been authed and handled by the state module, and sends it to any - remote home servers that may be interested. - - Args: - event: The event to send - destinations: A list of destinations to send it to - - Returns: - Deferred: Resolved when it has successfully been queued for - processing. - """ - - return self.replication_layer.send_pdu(event, destinations) - @log_function @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): @@ -268,9 +252,12 @@ class FederationHandler(BaseHandler): except: return False + # Parses mapping `event_id -> (type, state_key) -> state event_id` + # to get all state ids that we're interested in. event_map = yield self.store.get_events([ - e_id for key_to_eid in event_to_state_ids.values() - for key, e_id in key_to_eid + e_id + for key_to_eid in event_to_state_ids.values() + for key, e_id in key_to_eid.items() if key[0] != EventTypes.Member or check_match(key[1]) ]) @@ -830,25 +817,6 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - new_pdu = event - - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ) - - destinations.discard(origin) - - logger.debug( - "on_send_join_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids @@ -1055,24 +1023,6 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - new_pdu = event - - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ) - destinations.discard(origin) - - logger.debug( - "on_send_leave_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - defer.returnValue(None) @defer.inlineCallbacks diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbfa5a0281806906e93761054d8325584f1dade2..e0ade4c164d75deeb894fdcbdb141d7f4bb07c7f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( + receipts = yield self.store.get_linearized_receipts_for_room( room_id, - now_token.receipt_key + to_key=now_token.receipt_key, ) + if not receipts: + receipts = [] defer.returnValue(receipts) presence, receipts, (messages, token) = yield defer.gatherResults( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 81df45177a1588093b03c2e19a9ef5f93732bebd..7a57a69bd33dff9a785129a0f22816d808216524 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,9 +22,9 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.push.action_generator import ActionGenerator from synapse.types import ( - UserID, RoomAlias, RoomStreamToken, get_domain_from_id + UserID, RoomAlias, RoomStreamToken, ) -from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.logcontext import preserve_fn from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -50,6 +50,10 @@ class MessageHandler(BaseHandler): self.pagination_lock = ReadWriteLock() + # We arbitrarily limit concurrent event creation for a room to 5. + # This is to stop us from diverging history *too* much. + self.limiter = Limiter(max_count=5) + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -191,36 +195,38 @@ class MessageHandler(BaseHandler): """ builder = self.event_builder_factory.new(event_dict) - self.validator.validate_new(builder) - - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) - - if membership in {Membership.JOIN, Membership.INVITE}: - # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler - content = builder.content + with (yield self.limiter.queue(builder.room_id)): + self.validator.validate_new(builder) + + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + target = UserID.from_string(builder.state_key) + + if membership in {Membership.JOIN, Membership.INVITE}: + # If event doesn't include a display name, add one. + profile = self.hs.get_handlers().profile_handler + content = builder.content + + try: + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + except Exception as e: + logger.info( + "Failed to get profile information for %r: %s", + target, e + ) - try: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", - target, e - ) + if token_id is not None: + builder.internal_metadata.token_id = token_id - if token_id is not None: - builder.internal_metadata.token_id = token_id + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id - if txn_id is not None: - builder.internal_metadata.txn_id = txn_id + event, context = yield self._create_new_client_event( + builder=builder, + prev_event_ids=prev_event_ids, + ) - event, context = yield self._create_new_client_event( - builder=builder, - prev_event_ids=prev_event_ids, - ) defer.returnValue((event, context)) @defer.inlineCallbacks @@ -599,13 +605,6 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - users_in_room = yield self.store.get_joined_users_from_context(event, context) - - destinations = [ - get_domain_from_id(user_id) for user_id in users_in_room - if not self.hs.is_mine_id(user_id) - ] - @defer.inlineCallbacks def _notify(): yield run_on_reactor() @@ -618,7 +617,3 @@ class MessageHandler(BaseHandler): # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) - - preserve_fn(federation_handler.handle_new_event)( - event, destinations=destinations, - ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b047ae2250dad67891660312773ceda83b473581..1b89dc6274dd404045b37133cc137fe86bcc6bc1 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -91,28 +91,29 @@ class PresenceHandler(object): self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.federation = hs.get_replication_layer() + self.replication = hs.get_replication_layer() + self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence", self.incoming_presence ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e536a909d01b6205dbdcceb73a2d959d6267b170..50aa5139356b7cfc9fb736598fdac0cb3d40c35f 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_replication_layer() - self.federation.register_edu_handler( + self.federation = hs.get_federation_sender() + hs.get_replication_layer().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() @@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler): if not res: # res will be None if this read receipt is 'old' - defer.returnValue(False) + continue stream_id, max_persisted_id = res @@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler): if max_batch_id is None or max_persisted_id > max_batch_id: max_batch_id = max_persisted_id + if min_batch_id is None: + # no new receipts + defer.returnValue(False) + affected_room_ids = list(set([r["room_id"] for r in receipts])) with PreserveLoggingContext(): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7e119f13b1a1f8b27d9a0fde7244041338e536e3..286f0cef0a02ea0f289886a72eab6032376744a8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler): "User ID already taken.", errcode=Codes.USER_IN_USE, ) - user_data = yield self.auth.get_user_from_macaroon(guest_access_token) + user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, @@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token( - user_id, None, duration_in_ms) + token = self.auth_handler().generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 59e4d1cd15e0ada4beda073310fa7d460f11a57d..5f18007e90e4119944d4b1d06fc1bd4d39389baa 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler): "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": False, + "guest_can_join": True, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": True, + "guest_can_join": True, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, "history_visibility": "shared", "original_invitees_have_ops": False, + "guest_can_join": False, }, } @@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler): content={"history_visibility": config["history_visibility"]} ) + if config["guest_can_join"]: + if (EventTypes.GuestAccess, '') not in initial_state: + yield send( + etype=EventTypes.GuestAccess, + content={"guest_access": "can_join"} + ) + for (etype, state_key), content in initial_state.items(): yield send( etype=etype, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index b04aea0110cd18e5608d7aeaf50e94e8896a9088..667223df0c90ed8edcddf7877a6aed101d52de62 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -22,6 +22,7 @@ from synapse.api.constants import ( ) from synapse.util.async import concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.types import ThirdPartyInstanceID from collections import namedtuple from unpaddedbase64 import encode_base64, decode_base64 @@ -34,6 +35,10 @@ logger = logging.getLogger(__name__) REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 +# This is used to indicate we should only return rooms published to the main list. +EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) + + class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) @@ -41,22 +46,43 @@ class RoomListHandler(BaseHandler): self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) def get_local_public_room_list(self, limit=None, since_token=None, - search_filter=None): - if search_filter: - # We explicitly don't bother caching searches. - return self._get_public_room_list(limit, since_token, search_filter) + search_filter=None, + network_tuple=EMTPY_THIRD_PARTY_ID,): + """Generate a local public room list. + + There are multiple different lists: the main one plus one per third + party network. A client can ask for a specific list or to return all. + + Args: + limit (int) + since_token (str) + search_filter (dict) + network_tuple (ThirdPartyInstanceID): Which public list to use. + This can be (None, None) to indicate the main list, or a particular + appservice and network id to use an appservice specific one. + Setting to None returns all public rooms across all lists. + """ + if search_filter or (network_tuple and network_tuple.appservice_id is not None): + # We explicitly don't bother caching searches or requests for + # appservice specific lists. + return self._get_public_room_list( + limit, since_token, search_filter, network_tuple=network_tuple, + ) result = self.response_cache.get((limit, since_token)) if not result: result = self.response_cache.set( (limit, since_token), - self._get_public_room_list(limit, since_token) + self._get_public_room_list( + limit, since_token, network_tuple=network_tuple + ) ) return result @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, - search_filter=None): + search_filter=None, + network_tuple=EMTPY_THIRD_PARTY_ID,): if since_token and since_token != "END": since_token = RoomListNextBatch.from_token(since_token) else: @@ -73,14 +99,15 @@ class RoomListHandler(BaseHandler): current_public_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = since_token.public_room_stream_id newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id + public_room_stream_id, current_public_id, + network_tuple=network_tuple, ) else: stream_token = yield self.store.get_room_max_stream_ordering() public_room_stream_id = yield self.store.get_current_public_room_stream_id() room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id + public_room_stream_id, network_tuple=network_tuple, ) # We want to return rooms in a particular order: the number of joined @@ -311,7 +338,8 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def get_remote_public_room_list(self, server_name, limit=None, since_token=None, - search_filter=None): + search_filter=None, include_all_networks=False, + third_party_instance_id=None,): if search_filter: # We currently don't support searching across federation, so we have # to do it manually without pagination @@ -320,6 +348,8 @@ class RoomListHandler(BaseHandler): res = yield self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, ) if search_filter: @@ -332,22 +362,30 @@ class RoomListHandler(BaseHandler): defer.returnValue(res) def _get_remote_list_cached(self, server_name, limit=None, since_token=None, - search_filter=None): + search_filter=None, include_all_networks=False, + third_party_instance_id=None,): repl_layer = self.hs.get_replication_layer() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( server_name, limit=limit, since_token=since_token, - search_filter=search_filter, + search_filter=search_filter, include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, ) - result = self.remote_response_cache.get((server_name, limit, since_token)) + key = ( + server_name, limit, since_token, include_all_networks, + third_party_instance_id, + ) + result = self.remote_response_cache.get(key) if not result: result = self.remote_response_cache.set( - (server_name, limit, since_token), + key, repl_layer.get_public_rooms( server_name, limit=limit, since_token=since_token, search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, ) ) return result diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f910ff8149bb4430b0dcafc4e3396520e8fb99b..c880f616858d8a06d3dc65d98145336c31835364 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -277,6 +277,7 @@ class SyncHandler(object): """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() + block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -293,7 +294,7 @@ class SyncHandler(object): else: recents = [] - if not limited: + if not limited or block_all_timeline: defer.returnValue(TimelineBatch( events=recents, prev_batch=now_token, @@ -509,6 +510,7 @@ class SyncHandler(object): Returns: Deferred(SyncResult) """ + logger.info("Calculating sync response for %r", sync_config.user) # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability @@ -531,9 +533,14 @@ class SyncHandler(object): ) newly_joined_rooms, newly_joined_users = res - yield self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_users + block_all_presence_data = ( + since_token is None and + sync_config.filter_collection.blocks_all_presence() ) + if not block_all_presence_data: + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) yield self._generate_sync_entry_for_to_device(sync_result_builder) @@ -569,16 +576,20 @@ class SyncHandler(object): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - logger.debug("Deleting messages up to %d", since_stream_id) - yield self.store.delete_messages_for_device( + deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) + logger.info("Deleted %d to-device messages up to %d", + deleted, since_stream_id) - logger.debug("Getting messages up to %d", now_token.to_device_key) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) - logger.debug("Got messages up to %d: %r", stream_id, messages) + + logger.info( + "Returning %d to-device messages between %d and %d (current token: %d)", + len(messages), since_stream_id, stream_id, now_token.to_device_key + ) sync_result_builder.now_token = now_token.copy_and_replace( "to_device_key", stream_id ) @@ -709,13 +720,20 @@ class SyncHandler(object): `(newly_joined_rooms, newly_joined_users)` """ user_id = sync_result_builder.sync_config.user.to_string() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, - now_token=sync_result_builder.now_token, - since_token=sync_result_builder.since_token, + block_all_room_ephemeral = ( + sync_result_builder.since_token is None and + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) - sync_result_builder.now_token = now_token + + if block_all_room_ephemeral: + ephemeral_by_room = {} + else: + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 27ee715ff013ec799f5935e775d962d97d223610..0eea7f8f9c29f1d39d6387691c4ccb01862018e0 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -55,9 +55,9 @@ class TypingHandler(object): self.clock = hs.get_clock() self.wheel_timer = WheelTimer(bucket_size=5000) - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler("m.typing", self._recv_edu) + hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index d0556ae347276897a46edcb0465b52d9c67de444..d5970c05a8027180be1bdf0af1aae97992f59090 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from signedjson.sign import sign_json +import cgi import simplejson as json import logging import random @@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) defer.returnValue(json.loads(body)) @@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object): if 200 <= response.code < 300: # We need to update the transactions table to say it was sent? - c_type = response.headers.getRawHeaders("Content-Type") - - if "application/json" not in c_type: - raise RuntimeError( - "Content-Type not application/json" - ) + check_content_type_is_json(response.headers) body = yield preserve_context_over_fn(readBody, response) @@ -525,3 +511,29 @@ def _flatten_response_never_received(e): ) else: return "%s: %s" % (type(e).__name__, e.message,) + + +def check_content_type_is_json(headers): + """ + Check that a set of HTTP headers have a Content-Type header, and that it + is application/json. + + Args: + headers (twisted.web.http_headers.Headers): headers to check + + Raises: + RuntimeError if the + + """ + c_type = headers.getRawHeaders("Content-Type") + if c_type is None: + raise RuntimeError( + "No Content-Type header" + ) + + c_type = c_type[0] # only the first header + val, options = cgi.parse_header(c_type) + if val != "application/json": + raise RuntimeError( + "Content-Type not application/json: was '%s'" % c_type + ) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 934638623893277ed5f1c512238cfd3f4c0614e6..8c22d6f00f41735ebf74914dfe9648207cf837f2 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False): parameter is present and not one of "true" or "false". """ - if name in request.args: + return parse_boolean_from_args(request.args, name, default, required) + + +def parse_boolean_from_args(args, name, default=None, required=False): + if name in args: try: return { "true": True, "false": False, - }[request.args[name][0]] + }[args[name][0]] except: message = ( "Boolean query parameter %r must be one of" diff --git a/synapse/notifier.py b/synapse/notifier.py index 48653ae843c277435540a904d96cbe1099ad8930..acbd4bb5ae1199231ece55ebd2d0d740637d6895 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.util import DeferredTimedOutError from synapse.util.logutils import log_function from synapse.util.async import ObservableDeferred from synapse.util.logcontext import PreserveLoggingContext, preserve_fn @@ -143,6 +144,12 @@ class Notifier(object): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + else: + self.federation_sender = None + self.state_handler = hs.get_state_handler() self.clock.looping_call( @@ -220,6 +227,9 @@ class Notifier(object): # poke any interested application service. self.appservice_handler.notify_interested_services(room_stream_id) + if self.federation_sender: + self.federation_sender.notify_new_events(room_stream_id) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -285,14 +295,7 @@ class Notifier(object): result = None if timeout: - # Will be set to a _NotificationListener that we'll be waiting on. - # Allows us to cancel it. - listener = None - - def timed_out(): - if listener: - listener.deferred.cancel() - timer = self.clock.call_later(timeout / 1000., timed_out) + end_time = self.clock.time_msec() + timeout prev_token = from_token while not result: @@ -303,6 +306,10 @@ class Notifier(object): if result: break + now = self.clock.time_msec() + if end_time <= now: + break + # Now we wait for the _NotifierUserStream to be told there # is a new token. # We need to supply the token we supplied to callback so @@ -310,11 +317,14 @@ class Notifier(object): prev_token = current_token listener = user_stream.new_listener(prev_token) with PreserveLoggingContext(): - yield listener.deferred + yield self.clock.time_bound_deferred( + listener.deferred, + time_out=(end_time - now) / 1000. + ) + except DeferredTimedOutError: + break except defer.CancelledError: break - - self.clock.cancel_call_later(timer, ignore_errs=True) else: current_token = user_stream.current_token result = yield callback(from_token, current_token) @@ -483,22 +493,27 @@ class Notifier(object): """ listener = _NotificationListener(None) - def timed_out(): - listener.deferred.cancel() + end_time = self.clock.time_msec() + timeout - timer = self.clock.call_later(timeout / 1000., timed_out) while True: listener.deferred = self.replication_deferred.observe() result = yield callback() if result: break + now = self.clock.time_msec() + if end_time <= now: + break + try: with PreserveLoggingContext(): - yield listener.deferred + yield self.clock.time_bound_deferred( + listener.deferred, + time_out=(end_time - now) / 1000. + ) + except DeferredTimedOutError: + break except defer.CancelledError: break - self.clock.cancel_call_later(timer, ignore_errs=True) - defer.returnValue(result) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index be55598c437b006e3c64ada5e6bd67c6b277fdc7..78b095c9039c433787c522b3202235350c5f545d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,12 +87,12 @@ class BulkPushRuleEvaluator: condition_cache = {} for uid, rules in self.rules_by_user.items(): - display_name = None - member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) - if member_ev_id: - member_ev = yield self.store.get_event(member_ev_id, allow_none=True) - if member_ev: - display_name = member_ev.content.get("displayname", None) + display_name = room_members.get(uid, {}).get("display_name", None) + if not display_name: + # Handle the case where we are pushing a membership event to + # that user, as they might not be already joined. + if event.type == EventTypes.Member and event.state_key == uid: + display_name = event.content.get("displayname", None) filtered = filtered_by_user[uid] if len(filtered) == 0: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index b9e41770ee2ec964f33a0b2325982f8555b465cc..3742a25b378af3d3361c07470f60f439907c15dd 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = { "Jinja2>=2.8": ["Jinja2>=2.8"], "bleach>=1.4.2": ["bleach>=1.4.2"], }, - "ldap": { - "ldap3>=1.0": ["ldap3>=1.0"], + "matrix-synapse-ldap3": { + "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"], }, "psutil": { "psutil>=2.0.0": ["psutil>=2.0.0"], @@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False): def github_link(project, version, egg): return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) + DEPENDENCY_LINKS = { } @@ -156,6 +157,7 @@ def list_requirements(): result.append(requirement) return result + if __name__ == "__main__": import sys sys.stdout.writelines(req + "\n" for req in list_requirements()) diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..c05a50d7a68042d5dc95176344ec3603a326e7cb --- /dev/null +++ b/synapse/replication/expire_cache.py @@ -0,0 +1,60 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + + +class ExpireCacheResource(Resource): + """ + HTTP endpoint for expiring storage caches. + + POST /_synapse/replication/expire_cache HTTP/1.1 + Content-Type: application/json + + { + "invalidate": [ + { + "name": "func_name", + "keys": ["key1", "key2"] + } + ] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.store = hs.get_datastore() + self.version_string = hs.version_string + self.clock = hs.get_clock() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + for row in content["invalidate"]: + name = row["name"] + keys = tuple(row["keys"]) + + getattr(self.store, name).invalidate(keys) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 5a14c51d235ae90e21a037e3d71587d74b9e7512..4616e9b34a5de63086866c88ace503ff559bd86f 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request from synapse.replication.pusher_resource import PusherResource from synapse.replication.presence_resource import PresenceResource +from synapse.replication.expire_cache import ExpireCacheResource from synapse.api.errors import SynapseError from twisted.web.resource import Resource @@ -44,6 +45,7 @@ STREAM_NAMES = ( ("caches",), ("to_device",), ("public_rooms",), + ("federation",), ) @@ -116,11 +118,14 @@ class ReplicationResource(Resource): self.sources = hs.get_event_sources() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() + self.federation_sender = hs.get_federation_sender() self.notifier = hs.notifier self.clock = hs.get_clock() + self.config = hs.get_config() self.putChild("remove_pushers", PusherResource(hs)) self.putChild("syncing_users", PresenceResource(hs)) + self.putChild("expire_cache", ExpireCacheResource(hs)) def render_GET(self, request): self._async_render_GET(request) @@ -134,6 +139,7 @@ class ReplicationResource(Resource): pushers_token = self.store.get_pushers_stream_token() caches_token = self.store.get_cache_stream_token() public_rooms_token = self.store.get_current_public_room_stream_id() + federation_token = self.federation_sender.get_current_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -148,6 +154,7 @@ class ReplicationResource(Resource): caches_token, int(stream_token.to_device_key), int(public_rooms_token), + int(federation_token), )) @request_handler() @@ -164,8 +171,13 @@ class ReplicationResource(Resource): } request_streams["streams"] = parse_string(request, "streams") + federation_ack = parse_integer(request, "federation_ack", None) + def replicate(): - return self.replicate(request_streams, limit) + return self.replicate( + request_streams, limit, + federation_ack=federation_ack + ) writer = yield self.notifier.wait_for_replication(replicate, timeout) result = writer.finish() @@ -183,7 +195,7 @@ class ReplicationResource(Resource): finish_request(request) @defer.inlineCallbacks - def replicate(self, request_streams, limit): + def replicate(self, request_streams, limit, federation_ack=None): writer = _Writer() current_token = yield self.current_replication_token() logger.debug("Replicating up to %r", current_token) @@ -202,6 +214,7 @@ class ReplicationResource(Resource): yield self.caches(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams) + self.federation(writer, current_token, limit, request_streams, federation_ack) self.streams(writer, current_token, request_streams) logger.debug("Replicated %d rows", writer.total) @@ -462,7 +475,24 @@ class ReplicationResource(Resource): ) upto_token = _position_from_rows(public_rooms_rows, current_position) writer.write_header_and_rows("public_rooms", public_rooms_rows, ( - "position", "room_id", "visibility" + "position", "room_id", "visibility", "appservice_id", "network_id", + ), position=upto_token) + + def federation(self, writer, current_token, limit, request_streams, federation_ack): + if self.config.send_federation: + return + + current_position = current_token.federation + + federation = request_streams.get("federation") + + if federation is not None and federation != current_position: + federation_rows = self.federation_sender.get_replication_rows( + federation, limit, federation_ack=federation_ack, + ) + upto_token = _position_from_rows(federation_rows, current_position) + writer.write_header_and_rows("federation", federation_rows, ( + "position", "type", "content", ), position=upto_token) @@ -497,6 +527,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", + "federation", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f19540d6bbb13ba5a7a6e48d3afa577c1be08a1b..18076e0f3b549843d95ad09cfe670089dc689060 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore): else: self._cache_id_gen = None + self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" + self.http_client = hs.get_simple_http_client() + def stream_positions(self): pos = {} if self._cache_id_gen: @@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore): logger.info("Got unexpected cache_func: %r", cache_func) self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + txn.call_after(cache_func.invalidate, keys) + txn.call_after(self._send_invalidation_poke, cache_func, keys) + + @defer.inlineCallbacks + def _send_invalidation_poke(self, cache_func, keys): + try: + yield self.http_client.post_json_get_json(self.expire_cache_url, { + "invalidate": [{ + "name": cache_func.__name__, + "keys": list(keys), + }] + }) + except: + logger.exception("Failed to poke on expire_cache") diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 3bfd5e8213c45294af046dd6faef4419d299932f..cc860f9f9b047498f6cc48c9032297b1c68a0a0b 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore): "DeviceInboxStreamChangeCache", self._device_inbox_id_gen.get_current_token() ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + self._device_inbox_id_gen.get_current_token() + ) get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ + get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ delete_messages_for_device = DataStore.delete_messages_for_device.__func__ + delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__ def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() @@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore): self._device_inbox_id_gen.advance(int(stream["position"])) for row in stream["rows"]: stream_id = row[0] - user_id = row[1] - self._device_inbox_stream_cache.entity_has_changed( - user_id, stream_id - ) + entity = row[1] + + if entity.startswith("@"): + self._device_inbox_stream_cache.entity_has_changed( + entity, stream_id + ) + else: + self._device_federation_outbox_stream_cache.entity_has_changed( + entity, stream_id + ) return super(SlavedDeviceInboxStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 0c26e96e985a5da055790891b387a2cd5f5981fa..64f18bbb3e2e11b372ec0d7c088e759da2f4618e 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore from synapse.util.caches.stream_change_cache import StreamChangeCache import ujson as json +import logging + + +logger = logging.getLogger(__name__) + # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore): EventFederationStore.__dict__["_get_forward_extremeties_for_room"] ) + get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__ + + get_federation_out_pos = DataStore.get_federation_out_pos.__func__ + update_federation_out_pos = DataStore.update_federation_out_pos.__func__ + def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() @@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore): stream = result.get("events") if stream: self._stream_id_gen.advance(int(stream["position"])) + + if stream["rows"]: + logger.info("Got %d event rows", len(stream["rows"])) + for row in stream["rows"]: self._process_replication_row( row, backfilled=False, state_resets=state_resets diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 23c613863f43adc2fe715fa880f170fb1172b520..6df9a25ef311cea88a2b60fe97fdc507cc9f6235 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -15,6 +15,7 @@ from ._base import BaseSlavedStore from synapse.storage import DataStore +from synapse.storage.room import RoomStore from ._slaved_id_tracker import SlavedIdTracker @@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore): DataStore.get_current_public_room_stream_id.__func__ ) get_public_room_ids_at_stream_id = ( - DataStore.get_public_room_ids_at_stream_id.__func__ + RoomStore.__dict__["get_public_room_ids_at_stream_id"] ) get_public_room_ids_at_stream_id_txn = ( DataStore.get_public_room_ids_at_stream_id_txn.__func__ diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 6f2ba98af59ad24a0b75fce45e83a664b4b67046..fbb58f35da0d21af7e18be0fc5c30b991660a037 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from ._base import BaseSlavedStore from synapse.storage import DataStore from synapse.storage.transactions import TransactionStore @@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore class TransactionStore(BaseSlavedStore): get_destination_retry_timings = TransactionStore.__dict__[ "get_destination_retry_timings" - ].orig + ] _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ + set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__ + _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__ - # For now, don't record the destination rety timings - def set_destination_retry_timings(*args, **kwargs): - return defer.succeed(None) + prep_send_transaction = DataStore.prep_send_transaction.__func__ + delivered_txn = DataStore.delivered_txn.__func__ diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 09d08315941c7301b03eaea3ab4eb5f86f05e77c..8930f1826fcb8a3d7e17f3a887064f71e059e1a7 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -31,6 +31,7 @@ logger = logging.getLogger(__name__) def register_servlets(hs, http_server): ClientDirectoryServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server) + ClientAppserviceDirectoryListServer(hs).register(http_server) class ClientDirectoryServer(ClientV1RestServlet): @@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet): ) defer.returnValue((200, {})) + + +class ClientAppserviceDirectoryListServer(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" + ) + + def __init__(self, hs): + super(ClientAppserviceDirectoryListServer, self).__init__(hs) + self.store = hs.get_datastore() + self.handlers = hs.get_handlers() + + def on_PUT(self, request, network_id, room_id): + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + return self._edit(request, network_id, room_id, visibility) + + def on_DELETE(self, request, network_id, room_id): + return self._edit(request, network_id, room_id, "private") + + @defer.inlineCallbacks + def _edit(self, request, network_id, room_id, visibility): + requester = yield self.auth.get_user_by_req(request) + if not requester.app_service: + raise AuthError( + 403, "Only appservices can edit the appservice published room list" + ) + + yield self.handlers.directory_handler.edit_published_appservice_room_list( + requester.app_service.id, network_id, room_id, visibility, + ) + + defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 345018a8fc16f85bc467aa532e25ebdd4b8cc431..093bc072f41f49a4710e705991c35d1245cc5565 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet): password=login_submission["password"], ) device_id = yield self._register_device(user_id, login_submission) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, "device_id": device_id, } @@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet): yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) device_id = yield self._register_device(user_id, login_submission) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name"), ) result = { "user_id": user_id, # may have changed "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, "device_id": device_id, } @@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet): device_id = yield self._register_device( registered_user_id, login_submission ) - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id, - login_submission.get("initial_device_display_name") - ) + access_token = yield auth_handler.get_access_token_for_user_id( + registered_user_id, device_id, + login_submission.get("initial_device_display_name"), ) + result = { "user_id": registered_user_id, "access_token": access_token, - "refresh_token": refresh_token, "home_server": self.hs.hostname, } else: diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index b5a76fefac679ca90c197d79cb89e36966f28aa3..ecf7e311a952827a0f40c90dae03e67deea332f5 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet): def __init__(self, hs): super(CreateUserRestServlet, self).__init__(hs) self.store = hs.get_datastore() - self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet): if "displayname" not in user_json: raise SynapseError(400, "Expected 'displayname' key.") - if "duration_seconds" not in user_json: - raise SynapseError(400, "Expected 'duration_seconds' key.") - localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") - duration_seconds = 0 - try: - duration_seconds = int(user_json["duration_seconds"]) - except ValueError: - raise SynapseError(400, "Failed to parse 'duration_seconds'") - if duration_seconds > self.direct_user_creation_max_duration: - duration_seconds = self.direct_user_creation_max_duration password_hash = user_json["password_hash"].encode("utf-8") \ if user_json.get("password_hash") else None @@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet): requester=requester, localpart=localpart, displayname=displayname, - duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 3fb1f2deb33a7c5f04b04fd48e166d084b3df60b..eead435bfd80488111a640dcbd394c36a06a90f1 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import Filter -from synapse.types import UserID, RoomID, RoomAlias +from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.http.servlet import ( parse_json_object_from_request, parse_string, parse_integer @@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet): since_token = content.get("since", None) search_filter = content.get("filter", None) + include_all_networks = content.get("include_all_networks", False) + third_party_instance_id = content.get("third_party_instance_id", None) + + if include_all_networks: + network_tuple = None + if third_party_instance_id is not None: + raise SynapseError( + 400, "Can't use include_all_networks with an explicit network" + ) + elif third_party_instance_id is None: + network_tuple = ThirdPartyInstanceID(None, None) + else: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( @@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet): limit=limit, since_token=since_token, search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, ) else: data = yield handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, + network_tuple=network_tuple, ) defer.returnValue((200, data)) @@ -369,6 +386,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +class JoinedRoomMemberListRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") + + def __init__(self, hs): + super(JoinedRoomMemberListRestServlet, self).__init__(hs) + self.state = hs.get_state_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id): + yield self.auth.get_user_by_req(request) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + defer.returnValue((200, { + "joined": users_with_profile + })) + + # TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") @@ -692,6 +727,22 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) +class JoinedRoomsRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/joined_rooms$") + + def __init__(self, hs): + super(JoinedRoomsRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + rooms = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = set(r.room_id for r in rooms) # Ensure they're unique. + defer.returnValue((200, {"joined_rooms": list(room_ids)})) + + def register_txn_path(servlet, regex_string, http_server, with_get=False): """Registers a transaction-based path. @@ -727,6 +778,7 @@ def register_servlets(hs, http_server): RoomStateEventRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) + JoinedRoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server) @@ -738,4 +790,5 @@ def register_servlets(hs, http_server): RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) RoomEventContext(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 3ba0b0fc0738f9d10f80a26dca54371ca4fef1ab..a1feaf3d54edc7a44f41452ebe2e1ab0aea489c9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) devices = yield self.device_handler.get_devices_by_user( requester.user.to_string() ) @@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( requester.user.to_string(), device_id, @@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) body = servlet.parse_json_object_from_request(request) yield self.device_handler.update_device( diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f185f9a7744de86d424596bd710c5693495e143b..46789775b9e3e83447db315509ee4d74a2dfe650 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -94,10 +94,6 @@ class KeyUploadServlet(RestServlet): class KeyQueryServlet(RestServlet): """ - GET /keys/query/<user_id> HTTP/1.1 - - GET /keys/query/<user_id>/<device_id> HTTP/1.1 - POST /keys/query HTTP/1.1 Content-Type: application/json { @@ -131,11 +127,7 @@ class KeyQueryServlet(RestServlet): """ PATTERNS = client_v2_patterns( - "/keys/query(?:" - "/(?P<user_id>[^/]*)(?:" - "/(?P<device_id>[^/]*)" - ")?" - ")?", + "/keys/query$", releases=() ) @@ -149,31 +141,16 @@ class KeyQueryServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks - def on_POST(self, request, user_id, device_id): - yield self.auth.get_user_by_req(request) + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.query_devices(body, timeout) defer.returnValue((200, result)) - @defer.inlineCallbacks - def on_GET(self, request, user_id, device_id): - requester = yield self.auth.get_user_by_req(request) - timeout = parse_integer(request, "timeout", 10 * 1000) - auth_user_id = requester.user.to_string() - user_id = user_id if user_id else auth_user_id - device_ids = [device_id] if device_id else [] - result = yield self.e2e_keys_handler.query_devices( - {"device_keys": {user_id: device_ids}}, - timeout, - ) - defer.returnValue((200, result)) - class OneTimeKeyServlet(RestServlet): """ - GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1 - POST /keys/claim HTTP/1.1 { "one_time_keys": { @@ -191,9 +168,7 @@ class OneTimeKeyServlet(RestServlet): """ PATTERNS = client_v2_patterns( - "/keys/claim(?:/?|(?:/" - "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)" - ")?)", + "/keys/claim$", releases=() ) @@ -203,18 +178,8 @@ class OneTimeKeyServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks - def on_GET(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) - timeout = parse_integer(request, "timeout", 10 * 1000) - result = yield self.e2e_keys_handler.claim_one_time_keys( - {"one_time_keys": {user_id: {device_id: algorithm}}}, - timeout, - ) - defer.returnValue((200, result)) - - @defer.inlineCallbacks - def on_POST(self, request, user_id, device_id, algorithm): - yield self.auth.get_user_by_req(request) + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.claim_one_time_keys( diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 891cef99c6845ff4d4cc5530182c1b49c12c01d7..1fbff2edd8fca25253e50fd22c942cb83f1b7ec3 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet): super(ReceiptRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.receipts_handler = hs.get_handlers().receipts_handler + self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6cfb20866b1bc49e442ccfe1a76719632864c775..3e7a285e10589e815287930d3ea0884120383679 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError @@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + body = parse_json_object_from_request(request) + kind = "user" if "kind" in request.args: kind = request.args["kind"][0] if kind == "guest": - ret = yield self._do_guest_registration() + ret = yield self._do_guest_registration(body) defer.returnValue(ret) return elif kind != "user": @@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - body = parse_json_object_from_request(request) - # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None @@ -373,8 +374,7 @@ class RegisterRestServlet(RestServlet): def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user - Allocates device_id if one was not given; also creates access_token - and refresh_token. + Allocates device_id if one was not given; also creates access_token. Args: (str) user_id: full canonical @user:id @@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token, refresh_token = ( - yield self.auth_handler.get_login_tuple_for_user_id( + access_token = ( + yield self.auth_handler.get_access_token_for_user_id( user_id, device_id=device_id, initial_display_name=params.get("initial_device_display_name") ) @@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet): "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, - "refresh_token": refresh_token, "device_id": device_id, }) @@ -421,20 +420,28 @@ class RegisterRestServlet(RestServlet): ) @defer.inlineCallbacks - def _do_guest_registration(self): + def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: defer.returnValue((403, "Guest access is disabled")) user_id, _ = yield self.registration_handler.register( generate_token=False, make_guest=True ) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + access_token = self.auth_handler.generate_access_token( user_id, ["guest = true"] ) - # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok - # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, + "device_id": device_id, "access_token": access_token, "home_server": self.hs.hostname, })) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index ac660669f377165d3a5160e5b63c794d0a5e4d5d..d607bd29709d7f3930baa6ed177f6d39a03a9a95 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): @defer.inlineCallbacks def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 6fc63715aad89d1fbf3884f00e4ea02591d69772..7199ec883a321111bb41672eff69841145fb5ee1 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id, filter.event_fields ) invited = self.encode_invited( @@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id, filter.event_fields ) response_content = { @@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": formatted} - def encode_joined(self, rooms, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id + room, time_now, token_id, only_fields=event_fields ) return joined @@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id, joined=False + room, time_now, token_id, joined=False, only_fields=event_fields ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True, only_fields=None): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet): of transaction IDs joined (bool): True if the user is joined to this room - will mean we handle ephemeral events - + only_fields(list<str>): Optional. The list of event fields to include. Returns: dict[str, object]: the room, encoded in our response format """ @@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet): return serialize_event( event, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + only_event_fields=only_fields, ) state_dict = room.state diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 0d312c91d4273dc2f803d199d0b3959cc99ec1af..6e76b9e9c2f6d4af1c11059a9dcd1b4d4ff17dac 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -15,8 +15,8 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - self.hs = hs - self.store = hs.get_datastore() @defer.inlineCallbacks def on_POST(self, request): - body = parse_json_object_from_request(request) - try: - old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_auth_handler() - refresh_result = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token - ) - (user_id, new_refresh_token, device_id) = refresh_result - new_access_token = yield auth_handler.issue_access_token( - user_id, device_id - ) - defer.returnValue((200, { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - })) - except KeyError: - raise SynapseError(400, "Missing required key 'refresh_token'.") - except StoreError: - raise AuthError(403, "Did not recognize refresh token") + raise AuthError(403, "tokenrefresh is no longer supported.") def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 33f35fb44eb8f2b5b004ece6e975bcf2a4e36378..99760d622f788973f6535c8de3a4efffefac26d7 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -381,7 +381,10 @@ def _calc_og(tree, media_uri): if 'og:title' not in og: # do some basic spidering of the HTML title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - og['og:title'] = title[0].text.strip() if title else None + if title and title[0].text is not None: + og['og:title'] = title[0].text.strip() + else: + og['og:title'] = None if 'og:image' not in og: # TODO: extract a favicon failing all else @@ -543,5 +546,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # We always add an ellipsis because at the very least # we chopped mid paragraph. - description = new_desc.strip() + "…" + description = new_desc.strip() + u"…" return description if description else None diff --git a/synapse/server.py b/synapse/server.py index 374124a147fdc7538bc1adc7f020807177103f9f..0bfb4112698ccdc0f539a6b0b69f102b9d308683 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.federation import initialize_http_replication +from synapse.federation.send_queue import FederationRemoteSendQueue +from synapse.federation.transport.client import TransportLayerClient +from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler @@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler +from synapse.handlers.receipts import ReceiptsHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -124,6 +128,9 @@ class HomeServer(object): 'http_client_context_factory', 'simple_http_client', 'media_repository', + 'federation_transport_client', + 'federation_sender', + 'receipts_handler', ] def __init__(self, hostname, **kwargs): @@ -265,9 +272,30 @@ class HomeServer(object): def build_media_repository(self): return MediaRepository(self) + def build_federation_transport_client(self): + return TransportLayerClient(self) + + def build_federation_sender(self): + if self.should_send_federation(): + return TransactionQueue(self) + elif not self.config.worker_app: + return FederationRemoteSendQueue(self) + else: + raise Exception("Workers cannot send federation traffic") + + def build_receipts_handler(self): + return ReceiptsHandler(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) + def should_send_federation(self): + "Should this server be sending federation traffic directly?" + return self.config.send_federation and ( + not self.config.worker_app + or self.config.worker_app == "synapse.app.federation_sender" + ) + def _make_dependency_method(depname): def _get(hs): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9996f195a0a488a878e12213ae7d3e3ed092a0bf..fe936b3e6204308d29cdc786cee47c710713c3df 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") @@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore, ) self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() super(DataStore, self).__init__(hs) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d828d6ee1d8f4425fcf9c91dddebe3e885e7444e..b62c459d8b03994a20bf76ac97dccf6e8a21efcf 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -561,12 +561,17 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + sql = ( - "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" + "SELECT %(retcol)s FROM %(table)s %(where)s" ) % { "retcol": retcol, "table": table, - "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + "where": where, } txn.execute(sql, keyvalues.values()) @@ -744,10 +749,15 @@ class SQLBaseStore(object): @staticmethod def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - update_sql = "UPDATE %s SET %s WHERE %s" % ( + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( table, ", ".join("%s = ?" % (k,) for k in updatevalues), - " AND ".join("%s = ?" % (k,) for k in keyvalues) + where, ) txn.execute( diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 3d5994a58000b72f4e20f9d500c15d54a454d1e0..514570561f7ce859552a4fab6b82cfacb2e1bd08 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore): def get_app_services(self): return self.services_cache + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service + """ + for service in self.services_cache: + if service.is_interested_in_user(user_id): + return True + return False + def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index f640e7371460b3c41504e032dde6d4a3abbde135..2821eb89c9b83da8c958758087041921b9a7c56a 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore): device_id(str): The recipient device_id. up_to_stream_id(int): Where to delete messages up to. Returns: - A deferred that resolves when the messages have been deleted. + A deferred that resolves to the number of messages deleted. """ def delete_messages_for_device_txn(txn): sql = ( @@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore): " AND stream_id <= ?" ) txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount return self.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn @@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore): return defer.succeed([]) def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id FROM device_inbox" + "SELECT stream_id, user_id" + " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY stream_id" " ORDER BY stream_id ASC" - " LIMIT ?" ) - txn.execute(sql, (last_pos, current_pos, limit)) - stream_ids = txn.fetchall() - if not stream_ids: - return [] - max_stream_id_in_limit = stream_ids[-1] + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() sql = ( - "SELECT stream_id, user_id, device_id, message_json" - " FROM device_inbox" + "SELECT stream_id, destination" + " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" ) - txn.execute(sql, (last_pos, max_stream_id_in_limit)) - return txn.fetchall() + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn.fetchall()) + + return rows return self.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 9cd923eb93e184bb6a85c5a7f3b9b0b0fc912866..7de3e8c58cbff7272b8115c39908d565353b3ab6 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore): columns=["user_id", "stream_ordering"], ) + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: @@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore): topological_ordering, stream_ordering ) + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 sql = ( - "SELECT sum(notif), sum(highlight)" + "SELECT count(*)" " FROM event_push_actions ea" " WHERE" " user_id = ?" @@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id)) row = txn.fetchone() - if row: - return { - "notify_count": row[0] or 0, - "highlight_count": row[1] or 0, - } - else: - return {"notify_count": 0, "highlight_count": 0} + notify_count = row[0] if row else 0 + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 49aeb953bd01fed3de74599aaee6a7377b0ea2d0..ecb79c07ef8630ceac1fe188313734f700eb652d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -54,6 +54,7 @@ def encode_json(json_object): else: return json.dumps(json_object, ensure_ascii=False) + # These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 5248736816bac2b95f6081a640e93020b827cca4..a2ccc66ea7c8768293a7bc406f919a9428aeb587 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -16,6 +16,7 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.errors import SynapseError, Codes from synapse.util.caches.descriptors import cachedInlineCallbacks import simplejson as json @@ -24,6 +25,13 @@ import simplejson as json class FilteringStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_user_filter(self, user_localpart, filter_id): + # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail + # with a coherent error message rather than 500 M_UNKNOWN. + try: + int(filter_id) + except ValueError: + raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) + def_json = yield self._simple_select_one_onecol( table="user_filters", keyvalues={ diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6576a3009888b9e463db629dfb98133e4913e8ac..e46ae6502ec88792e6ef3bd33e58126cc5ccee31 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 38 +SCHEMA_VERSION = 39 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 21d0696640a3e63fa2b7bb9a9036ecc304e605e9..7460f98a1fec58420fd04260008cdf2c210532c9 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState", status_msg (str): User set status message. """ + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + def copy_and_replace(self, **kwargs): return self._replace(**kwargs) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 49721656b64423f459865527618630d2fea3ee28..cbec255966ad0ce3c8f8231fe8a8c7ff78adbc1a 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore): event=event, ) - local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = set( + u for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + ) # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate, + local_users_in_room, + on_invalidate=cache_context.invalidate, ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 9747a04a9a87323b05ad2bc1120b7c30b5118e82..f72d15f5ed33a0c85d9606f64c9acb94edfbe6d1 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_current_token() + max_persisted_id = self._receipts_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index e404fa72de8012d9348e25c0d616da88995cb0e0..983a8ec52bcd156a882c4691a1ee1ef7123002a4 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): desc="add_access_token_to_user", ) - @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token, device_id=None): - """Adds a refresh token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new refresh token to add. - device_id (str): ID of the device to associate with the access - token - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._refresh_tokens_id_gen.get_next() - - yield self._simple_insert( - "refresh_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - }, - desc="add_refresh_token_to_user", - ) - def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): @@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): token ) - def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new one. - - Doing so invalidates the old refresh token - refresh tokens are single - use. - - Args: - refresh_token (str): The refresh token of a user. - token_generator (fn: str -> str): Function which, when given a - user ID, returns a unique refresh token for that user. This - function must never return the same value twice. - Returns: - tuple of (user_id, new_refresh_token, device_id) - Raises: - StoreError if no user was found with that refresh token. - """ - return self.runInteraction( - "exchange_refresh_token", - self._exchange_refresh_token, - refresh_token, - token_generator - ) - - def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" - txn.execute(sql, (old_token,)) - rows = self.cursor_to_dict(txn) - if not rows: - raise StoreError(403, "Did not recognize refresh token") - user_id = rows[0]["user_id"] - device_id = rows[0]["device_id"] - - # TODO(danielwh): Maybe perform a validation on the macaroon that - # macaroon.user_id == user_id. - - new_token = token_generator(user_id) - sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" - txn.execute(sql, (new_token, old_token,)) - - return user_id, new_token, device_id - @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 11813b44f60c85453a4ad065dda11c37959e2b81..8a2fe2fdf572b188175a316e1a000c86015ffcd5 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore from .engines import PostgresEngine, Sqlite3Engine @@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore): entries = self._simple_select_list_txn( txn, table="public_room_list_stream", - keyvalues={"room_id": room_id}, + keyvalues={ + "room_id": room_id, + "appservice_id": None, + "network_id": None, + }, retcols=("stream_id", "visibility"), ) @@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore): "stream_id": next_id, "room_id": room_id, "visibility": is_public, + "appservice_id": None, + "network_id": None, } ) @@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore): "set_room_is_public", set_room_is_public_txn, next_id, ) + self.hs.get_notifier().on_new_replication_data() + + @defer.inlineCallbacks + def set_room_is_public_appservice(self, room_id, appservice_id, network_id, + is_public): + """Edit the appservice/network specific public room list. + + Each appservice can have a number of published room lists associated + with them, keyed off of an appservice defined `network_id`, which + basically represents a single instance of a bridge to a third party + network. + + Args: + room_id (str) + appservice_id (str) + network_id (str) + is_public (bool): Whether to publish or unpublish the room from the + list. + """ + def set_room_is_public_appservice_txn(txn, next_id): + if is_public: + try: + self._simple_insert_txn( + txn, + table="appservice_room_list", + values={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id + }, + ) + except self.database_engine.module.IntegrityError: + # We've already inserted, nothing to do. + return + else: + self._simple_delete_txn( + txn, + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id + }, + ) + + entries = self._simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": appservice_id, + "network_id": network_id, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": appservice_id, + "network_id": network_id, + } + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "set_room_is_public_appservice", + set_room_is_public_appservice_txn, next_id, + ) + self.hs.get_notifier().on_new_replication_data() def get_public_room_ids(self): return self._simple_select_onecol( @@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_public_room_ids_at_stream_id(self, stream_id): + @cached(num_args=2, max_entries=100) + def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): + """Get pulbic rooms for a particular list, or across all lists. + + Args: + stream_id (int) + network_tuple (ThirdPartyInstanceID): The list to use (None, None) + means the main list, None means all lsits. + """ return self.runInteraction( "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, stream_id + self.get_public_room_ids_at_stream_id_txn, + stream_id, network_tuple=network_tuple ) - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, + network_tuple): return { rm - for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() + for rm, vis in self.get_published_at_stream_id_txn( + txn, stream_id, network_tuple=network_tuple + ).items() if vis } - def get_published_at_stream_id_txn(self, txn, stream_id): - sql = (""" - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id + def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): + if network_tuple: + # We want to get from a particular list. No aggregation required. + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? %s + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + if network_tuple.appservice_id is not None: + txn.execute( + sql % ("AND appservice_id = ? AND network_id = ?",), + (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + ) + else: + txn.execute( + sql % ("AND appservice_id IS NULL",), + (stream_id,) + ) + return dict(txn.fetchall()) + else: + # We want to get from all lists, so we need to aggregate the results + + logger.info("Executing full list") + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """) + INNER JOIN ( + SELECT + room_id, max(stream_id) AS stream_id, appservice_id, + network_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id, appservice_id, network_id + ) grouped USING (room_id, stream_id) + """) - txn.execute(sql, (stream_id,)) - return dict(txn.fetchall()) + txn.execute( + sql, + (stream_id,) + ) + + results = {} + # A room is visible if its visible on any list. + for room_id, visibility in txn.fetchall(): + results[room_id] = bool(visibility) or results.get(room_id, False) - def get_public_room_changes(self, prev_stream_id, new_stream_id): + return results + + def get_public_room_changes(self, prev_stream_id, new_stream_id, + network_tuple): def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) + then_rooms = self.get_public_room_ids_at_stream_id_txn( + txn, prev_stream_id, network_tuple + ) - now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) + now_rooms_dict = self.get_published_at_stream_id_txn( + txn, new_stream_id, network_tuple + ) now_rooms_visible = set( rm for rm, vis in now_rooms_dict.items() if vis @@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore): def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(txn): sql = (""" - SELECT stream_id, room_id, visibility FROM public_room_list_stream + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 866d64e679066daf3044b93b38f0735056983651..946d5a81cc6a047996c7d6e88d66d35f5e9593ed 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging +import ujson as json logger = logging.getLogger(__name__) @@ -34,7 +35,15 @@ RoomsForUser = namedtuple( ) +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" + + class RoomMemberStore(SQLBaseStore): + def __init__(self, hs): + super(RoomMemberStore, self).__init__(hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. @@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore): "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), } for event in events ] @@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore): table="room_memberships", column="event_id", iterable=member_event_ids, - retcols=['user_id'], + retcols=['user_id', 'display_name', 'avatar_url'], keyvalues={ "membership": Membership.JOIN, }, @@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore): desc="_get_joined_users_from_context", ) - users_in_room = set(row["user_id"] for row in rows) + users_in_room = { + row["user_id"]: { + "display_name": row["display_name"], + "avatar_url": row["avatar_url"], + } + for row in rows + } + if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room.add(event.state_key) + users_in_room[event.state_key] = { + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } defer.returnValue(users_in_room) @@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(True) defer.returnValue(False) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = (""" + SELECT stream_ordering, event_id, events.room_id, content + FROM events + INNER JOIN room_memberships USING (event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + content = json.loads(row["content"]) + except: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append(( + display_name, avatar_url, event_id, room_id + )) + + to_update_sql = (""" + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """) + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index:index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + + defer.returnValue(result) diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/schema/delta/39/appservice_room_list.sql new file mode 100644 index 0000000000000000000000000000000000000000..74bdc4907380997fb7fa9aa6c2047b4e5c2bf04c --- /dev/null +++ b/synapse/storage/schema/delta/39/appservice_room_list.sql @@ -0,0 +1,29 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE appservice_room_list( + appservice_id TEXT NOT NULL, + network_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +-- Each appservice can have multiple published room lists associated with them, +-- keyed of a particular network_id +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( + appservice_id, network_id, room_id +); + +ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT; +ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT; diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql new file mode 100644 index 0000000000000000000000000000000000000000..00be801e901fbe942fa5a268c8bd1490b6283483 --- /dev/null +++ b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/schema/delta/39/event_push_index.sql new file mode 100644 index 0000000000000000000000000000000000000000..de2ad93e5cdcb20fd35194f812222ab896c86a41 --- /dev/null +++ b/synapse/storage/schema/delta/39/event_push_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/schema/delta/39/federation_out_position.sql new file mode 100644 index 0000000000000000000000000000000000000000..5af814290bb0a617c18166b804dcacadbec5b690 --- /dev/null +++ b/synapse/storage/schema/delta/39/federation_out_position.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + CREATE TABLE federation_stream_position( + type TEXT NOT NULL, + stream_id INTEGER NOT NULL + ); + + INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); + INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/schema/delta/39/membership_profile.sql new file mode 100644 index 0000000000000000000000000000000000000000..1bf911c8abb0f1b2d6e08a12b673aa39a32c2d37 --- /dev/null +++ b/synapse/storage/schema/delta/39/membership_profile.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_memberships ADD COLUMN display_name TEXT; +ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT; + +INSERT into background_updates (update_name, progress_json) + VALUES ('room_membership_profile_update', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49abf0ac748a21c377bdd4f3444f77648843ab8d..23e7ad9922c66a1c0f90bd4cf6597c784c37ce24 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -653,7 +653,10 @@ class StateStore(SQLBaseStore): else: state_dict = results[group] - state_dict.update(group_state_dict) + state_dict.update({ + (intern_string(k[0]), intern_string(k[1])): v + for k, v in group_state_dict.items() + }) self._state_group_cache.update( cache_seq_num, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 888b1cb35dc594d076d8e99e95ef914d81881014..2dc24951c4deb8ecc5484e8ba623090bb5dde42e 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore): def get_room_max_stream_ordering(self): return self._stream_id_gen.get_current_token() + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() + def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore): "token": end_token, }, } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events""" + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn, + ) + + events = yield self._get_events(event_ids) + + defer.returnValue((upper_bound, events)) + + def get_federation_out_pos(self, typ): + return self._simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ}, + desc="get_federation_out_pos" + ) + + def update_federation_out_pos(self, typ, stream_id): + return self._simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index adab520c7840e8d489e2a132cbafb683c42b2dff..809fdd311f10c78951e13ebc8a7506812deb5abe 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -200,25 +200,48 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) + self.database_engine.lock_table(txn, "destinations") - self._simple_upsert_txn( + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + + # We need to be careful here as the data may have changed from under us + # due to a worker setting the timings. + + prev_row = self._simple_select_one_txn( txn, - "destinations", + table="destinations", keyvalues={ "destination": destination, }, - values={ - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - insertion_values={ - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - } + retcols=("retry_last_ts", "retry_interval"), + allow_none=True, ) + if not prev_row: + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } + ) + elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + self._simple_update_one_txn( + txn, + "destinations", + keyvalues={ + "destination": destination, + }, + updatevalues={ + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. diff --git a/synapse/types.py b/synapse/types.py index ffab12df09603744d0d3080f570504459f21b45f..3a3ab21d17ef80bb23fa2ae5a165e56fbee736dd 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) + + +class ThirdPartyInstanceID( + namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) +): + # Deny iteration because it will bite you if you try to create a singleton + # set by: + # users = set(user) + def __iter__(self): + raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + + # Because this class is a namedtuple of strings, it is deeply immutable. + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + @classmethod + def from_string(cls, s): + bits = s.split("|", 2) + if len(bits) != 2: + raise SynapseError(400, "Invalid ID %r" % (s,)) + + return cls(appservice_id=bits[0], network_id=bits[1]) + + def to_string(self): + return "%s|%s" % (self.appservice_id, self.network_id,) + + __str__ = to_string + + @classmethod + def create(cls, appservice_id, network_id,): + return cls(appservice_id=appservice_id, network_id=network_id) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index c05b9450beaa4f04127167e0e30f6fab4b9ad77c..30fc480108eecddcc588b393578fe96fdddcaefd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -24,6 +24,11 @@ import logging logger = logging.getLogger(__name__) +class DeferredTimedOutError(SynapseError): + def __init__(self): + super(SynapseError).__init__(504, "Timed out") + + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -89,7 +94,7 @@ class Clock(object): def timed_out_fn(): try: - ret_deferred.errback(SynapseError(504, "Timed out")) + ret_deferred.errback(DeferredTimedOutError()) except: pass diff --git a/synapse/util/async.py b/synapse/util/async.py index 347fb1e38096dd61f18e6502790b00b6a1d843d1..16ed183d4c1d15129ce93b7410200429bd0d5619 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -197,6 +197,64 @@ class Linearizer(object): defer.returnValue(_ctx_manager()) +class Limiter(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few thing happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, max_count): + """ + Args: + max_count(int): The maximum number of concurrent access + """ + self.max_count = max_count + + # key_to_defer is a map from the key to a 2 element list where + # the first element is the number of things executing + # the second element is a list of deferreds for the things blocked from + # executing. + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, []]) + + # If the number of things executing is greater than the maximum + # then add a deferred to the list of blocked items + # When on of the things currently executing finishes it will callback + # this item so that it can continue executing. + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1].append(new_defer) + with PreserveLoggingContext(): + yield new_defer + + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry[0] -= 1 + try: + entry[1].pop(0).callback(None) + except IndexError: + # If nothing else is executing for this key then remove it + # from the map + if entry[0] == 0: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) + + class ReadWriteLock(object): """A deferred style read write lock. diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 3fd5c3d9fd1b8d74249046b571008e59c2919613..d668e5a6b8cb606d435f5070dfd878b3f92329e9 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -76,15 +76,26 @@ class JsonEncodedObject(object): d.update(self.unrecognized_keys) return d + def get_internal_dict(self): + d = { + k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + if k in self.valid_keys + } + d.update(self.unrecognized_keys) + return d + def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) -def _encode(obj): +def _encode(obj, internal=False): if type(obj) is list: - return [_encode(o) for o in obj] + return [_encode(o, internal=internal) for o in obj] if isinstance(obj, JsonEncodedObject): - return obj.get_dict() + if internal: + return obj.get_internal_dict() + else: + return obj.get_dict() return obj diff --git a/synapse/util/ldap_auth_provider.py b/synapse/util/ldap_auth_provider.py deleted file mode 100644 index 1b989248fb1814f590e2c0f9457523e9a13ba9b3..0000000000000000000000000000000000000000 --- a/synapse/util/ldap_auth_provider.py +++ /dev/null @@ -1,369 +0,0 @@ - -from twisted.internet import defer - -from synapse.config._base import ConfigError -from synapse.types import UserID - -import ldap3 -import ldap3.core.exceptions - -import logging - -try: - import ldap3 - import ldap3.core.exceptions -except ImportError: - ldap3 = None - pass - - -logger = logging.getLogger(__name__) - - -class LDAPMode(object): - SIMPLE = "simple", - SEARCH = "search", - - LIST = (SIMPLE, SEARCH) - - -class LdapAuthProvider(object): - __version__ = "0.1" - - def __init__(self, config, account_handler): - self.account_handler = account_handler - - if not ldap3: - raise RuntimeError( - 'Missing ldap3 library. This is required for LDAP Authentication.' - ) - - self.ldap_mode = config.mode - self.ldap_uri = config.uri - self.ldap_start_tls = config.start_tls - self.ldap_base = config.base - self.ldap_attributes = config.attributes - if self.ldap_mode == LDAPMode.SEARCH: - self.ldap_bind_dn = config.bind_dn - self.ldap_bind_password = config.bind_password - self.ldap_filter = config.filter - - @defer.inlineCallbacks - def check_password(self, user_id, password): - """ Attempt to authenticate a user against an LDAP Server - and register an account if none exists. - - Returns: - True if authentication against LDAP was successful - """ - localpart = UserID.from_string(user_id).localpart - - try: - server = ldap3.Server(self.ldap_uri) - logger.debug( - "Attempting LDAP connection with %s", - self.ldap_uri - ) - - if self.ldap_mode == LDAPMode.SIMPLE: - result, conn = self._ldap_simple_bind( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP authentication method simple bind returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - elif self.ldap_mode == LDAPMode.SEARCH: - result, conn = self._ldap_authenticated_search( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP auth method authenticated search returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - else: - raise RuntimeError( - 'Invalid LDAP mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) - - try: - logger.info( - "User authenticated against LDAP server: %s", - conn - ) - except NameError: - logger.warn( - "Authentication method yielded no LDAP connection, aborting!" - ) - defer.returnValue(False) - - # check if user with user_id exists - if (yield self.account_handler.check_user_exists(user_id)): - # exists, authentication complete - conn.unbind() - defer.returnValue(True) - - else: - # does not exist, fetch metadata for account creation from - # existing ldap connection - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - - if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: - query = "(&{filter}{user_filter})".format( - filter=query, - user_filter=self.ldap_filter - ) - logger.debug( - "ldap registration filter: %s", - query - ) - - conn.search( - search_base=self.ldap_base, - search_filter=query, - attributes=[ - self.ldap_attributes['name'], - self.ldap_attributes['mail'] - ] - ) - - if len(conn.response) == 1: - attrs = conn.response[0]['attributes'] - mail = attrs[self.ldap_attributes['mail']][0] - name = attrs[self.ldap_attributes['name']][0] - - # create account - user_id, access_token = ( - yield self.account_handler.register(localpart=localpart) - ) - - # TODO: bind email, set displayname with data from ldap directory - - logger.info( - "Registration based on LDAP data was successful: %d: %s (%s, %)", - user_id, - localpart, - name, - mail - ) - - defer.returnValue(True) - else: - if len(conn.response) == 0: - logger.warn("LDAP registration failed, no result.") - else: - logger.warn( - "LDAP registration failed, too many results (%s)", - len(conn.response) - ) - - defer.returnValue(False) - - defer.returnValue(False) - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during ldap authentication: %s", e) - defer.returnValue(False) - - @staticmethod - def parse_config(config): - class _LdapConfig(object): - pass - - ldap_config = _LdapConfig() - - ldap_config.enabled = config.get("enabled", False) - - ldap_config.mode = LDAPMode.SIMPLE - - # verify config sanity - _require_keys(config, [ - "uri", - "base", - "attributes", - ]) - - ldap_config.uri = config["uri"] - ldap_config.start_tls = config.get("start_tls", False) - ldap_config.base = config["base"] - ldap_config.attributes = config["attributes"] - - if "bind_dn" in config: - ldap_config.mode = LDAPMode.SEARCH - _require_keys(config, [ - "bind_dn", - "bind_password", - ]) - - ldap_config.bind_dn = config["bind_dn"] - ldap_config.bind_password = config["bind_password"] - ldap_config.filter = config.get("filter", None) - - # verify attribute lookup - _require_keys(config['attributes'], [ - "uid", - "name", - "mail", - ]) - - return ldap_config - - def _ldap_simple_bind(self, server, localpart, password): - """ Attempt a simple bind with the credentials - given by the user against the LDAP server. - - Returns True, LDAP3Connection - if the bind was successful - Returns False, None - if an error occured - """ - - try: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base - ) - conn = ldap3.Connection(server, bind_dn, password, - authentication=ldap3.AUTH_SIMPLE) - logger.debug( - "Established LDAP connection in simple bind mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in simple bind mode through StartTLS: %s", - conn - ) - - if conn.bind(): - # GOOD: bind okay - logger.debug("LDAP Bind successful in simple bind mode.") - return True, conn - - # BAD: bind failed - logger.info( - "Binding against LDAP failed for '%s' failed: %s", - localpart, conn.result['description'] - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - def _ldap_authenticated_search(self, server, localpart, password): - """ Attempt to login with the preconfigured bind_dn - and then continue searching and filtering within - the base_dn - - Returns (True, LDAP3Connection) - if a single matching DN within the base was found - that matched the filter expression, and with which - a successful bind was achieved - - The LDAP3Connection returned is the instance that was used to - verify the password not the one using the configured bind_dn. - Returns (False, None) - if an error occured - """ - - try: - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password - ) - logger.debug( - "Established LDAP connection in search mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in search mode through StartTLS: %s", - conn - ) - - if not conn.bind(): - logger.warn( - "Binding against LDAP with `bind_dn` failed: %s", - conn.result['description'] - ) - conn.unbind() - return False, None - - # construct search_filter like (uid=localpart) - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - if self.ldap_filter: - # combine with the AND expression - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug( - "LDAP search filter: %s", - query - ) - conn.search( - search_base=self.ldap_base, - search_filter=query - ) - - if len(conn.response) == 1: - # GOOD: found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('LDAP search found dn: %s', user_dn) - - # unbind and simple bind with user_dn to verify the password - # Note: do not use rebind(), for some reason it did not verify - # the password for me! - conn.unbind() - return self._ldap_simple_bind(server, localpart, password) - else: - # BAD: found 0 or > 1 results, abort! - if len(conn.response) == 0: - logger.info( - "LDAP search returned no results for '%s'", - localpart - ) - else: - logger.info( - "LDAP search returned too many (%s) results for '%s'", - len(conn.response), localpart - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - -def _require_keys(config, required): - missing = [key for key in required if key not in config] - if missing: - raise ConfigError( - "LDAP enabled but missing required config values: {}".format( - ", ".join(missing) - ) - ) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 49527f4d21da1cc57d8622a9a188dbbd965844fa..e2de7fce911a6ca0d76a5047ef18af2b766cf84e 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -121,15 +121,9 @@ class RetryDestinationLimiter(object): pass def __exit__(self, exc_type, exc_val, exc_tb): - def err(failure): - logger.exception( - "Failed to store set_destination_retry_timings", - failure.value - ) - valid_err_code = False if exc_type is not None and issubclass(exc_type, CodeMessageException): - valid_err_code = 0 <= exc_val.code < 500 + valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 if exc_type is None or valid_err_code: # We connected successfully. @@ -151,6 +145,15 @@ class RetryDestinationLimiter(object): retry_last_ts = int(self.clock.time_msec()) - self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval - ).addErrback(err) + @defer.inlineCallbacks + def store_retry_timings(): + try: + yield self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ) + except: + logger.exception( + "Failed to store set_destination_retry_timings", + ) + + store_retry_timings() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 2cf262bb466ac01c66c46908a061d0ed3c9ac6e1..4575dd98344ae988fee86e5bf0774cf08e8fa242 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,17 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest -from twisted.internet import defer +import pymacaroons from mock import Mock +from twisted.internet import defer +import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver, mock_getRawHeaders -import pymacaroons + +class TestHandlers(object): + def __init__(self, hs): + self.auth_handler = synapse.handlers.auth.AuthHandler(hs) class AuthTestCase(unittest.TestCase): @@ -34,14 +39,17 @@ class AuthTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) + self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) self.test_user = "@foo:bar" self.test_token = "_test_token_" + # this is overridden for the appservice tests + self.store.get_app_service_by_token = Mock(return_value=None) + @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, "token_id": "ditto", @@ -56,7 +64,6 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) @@ -66,7 +73,6 @@ class AuthTestCase(unittest.TestCase): self.failureResultOf(d, AuthError) def test_get_user_by_req_user_missing_token(self): - self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, "token_id": "ditto", @@ -158,7 +164,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -168,6 +174,10 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): + self.store.get_user_by_id = Mock(return_value={ + "is_guest": True, + }) + user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, @@ -179,11 +189,12 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_from_macaroon(serialized) + user_info = yield self.auth.get_user_by_access_token(serialized) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) self.assertTrue(is_guest) + self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): @@ -200,7 +211,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("User mismatch", cm.exception.msg) @@ -220,7 +231,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("No user caveat", cm.exception.msg) @@ -242,7 +253,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -265,7 +276,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("cunning > fox") with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -293,12 +304,12 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True - # yield self.auth.get_user_from_macaroon(macaroon.serialize()) + # yield self.auth.get_user_by_access_token(macaroon.serialize()) # TODO(daniel): Turn on the check that we validate expiration, when we # validate expiration (and remove the above line, which will start # throwing). with self.assertRaises(AuthError) as cm: - yield self.auth.get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_by_access_token(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -327,6 +338,58 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True - user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + + @defer.inlineCallbacks + def test_cannot_use_regular_token_as_guest(self): + USER_ID = "@percy:matrix.org" + self.store.add_access_token_to_user = Mock() + + token = yield self.hs.handlers.auth_handler.issue_access_token( + USER_ID, "DEVICE" + ) + self.store.add_access_token_to_user.assert_called_with( + USER_ID, token, "DEVICE" + ) + + def get_user(tok): + if token != tok: + return None + return { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + self.store.get_user_by_access_token = get_user + self.store.get_user_by_id = Mock(return_value={ + "is_guest": False, + }) + + # check the token works + request = Mock(args={}) + request.args["access_token"] = [token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + self.assertEqual(UserID.from_string(USER_ID), requester.user) + self.assertFalse(requester.is_guest) + + # add an is_guest caveat + mac = pymacaroons.Macaroon.deserialize(token) + mac.add_first_party_caveat("guest = true") + guest_tok = mac.serialize() + + # the token should *not* work now + request = Mock(args={}) + request.args["access_token"] = [guest_tok] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + with self.assertRaises(AuthError) as cm: + yield self.auth.get_user_by_req(request, allow_guest=True) + + self.assertEqual(401, cm.exception.code) + self.assertEqual("Guest access token used for regular user", cm.exception.msg) + + self.store.get_user_by_id.assert_called_with(USER_ID) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index fb0953c4ec097a697e2c502164d23c35623c1856..29f068d1f1514b41741effc371e567935f4a3e99 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -17,7 +17,11 @@ from .. import unittest from synapse.events import FrozenEvent -from synapse.events.utils import prune_event +from synapse.events.utils import prune_event, serialize_event + + +def MockEvent(**kwargs): + return FrozenEvent(kwargs) class PruneEventTestCase(unittest.TestCase): @@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase): 'unsigned': {}, } ) + + +class SerializeEventTestCase(unittest.TestCase): + + def serialize(self, ev, fields): + return serialize_event(ev, 1479807801915, only_event_fields=fields) + + def test_event_fields_works_with_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar" + ), + ["room_id"] + ), + { + "room_id": "!foo:bar", + } + ) + + def test_event_fields_works_with_nested_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "body": "A message", + }, + ), + ["content.body"] + ), + { + "content": { + "body": "A message", + } + } + ) + + def test_event_fields_works_with_dot_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "key.with.dots": {}, + }, + ), + ["content.key\.with\.dots"] + ), + { + "content": { + "key.with.dots": {}, + } + } + ) + + def test_event_fields_works_with_nested_dot_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "not_me": 1, + "nested.dot.key": { + "leaf.key": 42, + "not_me_either": 1, + }, + }, + ), + ["content.nested\.dot\.key.leaf\.key"] + ), + { + "content": { + "nested.dot.key": { + "leaf.key": 42, + }, + } + } + ) + + def test_event_fields_nops_with_unknown_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + ["content.foo", "content.notexists"] + ), + { + "content": { + "foo": "bar", + } + } + ) + + def test_event_fields_nops_with_non_dict_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": ["I", "am", "an", "array"], + }, + ), + ["content.foo.am"] + ), + {} + ) + + def test_event_fields_nops_with_array_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": ["I", "am", "an", "array"], + }, + ), + ["content.foo.1"] + ), + {} + ) + + def test_event_fields_all_fields_if_empty(self): + self.assertEquals( + self.serialize( + MockEvent( + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + [] + ), + { + "room_id": "!foo:bar", + "content": { + "foo": "bar", + }, + "unsigned": {} + } + ) + + def test_event_fields_fail_if_fields_not_str(self): + with self.assertRaises(TypeError): + self.serialize( + MockEvent( + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + ["room_id", 4] + ) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 4a8cd19acf18b49262482886f165940a914c61f2..9d013e5ca77c500f437aeb45924373270a3a12a8 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase): def verify_type(caveat): return caveat == "type = access" - def verify_expiry(caveat): - return caveat == "time < 8600000" + def verify_nonce(caveat): + return caveat.startswith("nonce =") v = pymacaroons.Verifier() v.satisfy_general(verify_gen) v.satisfy_general(verify_user) v.satisfy_general(verify_type) - v.satisfy_general(verify_expiry) + v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) def test_short_term_login_token_gives_user_id(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 9c9d1446906f9c25485b3985fd92fb13ffeeaaa1..a4380c48b406474bf009eb51110bffab9ab1006c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - duration_ms = 200 local_part = "someone" display_name = "someone" user_id = "@someone:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name, duration_ms) + requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase): user_id=frank.to_string(), token="jkv;g498752-43gj['eamb!-5", password_hash=None) - duration_ms = 200 local_part = "frank" display_name = "Frank" user_id = "@frank:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name, duration_ms) + requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index f406934a62e4db62a0327b6444f805b80fea3e4c..93b9fad0125fa69449771d90e311344dcb1cfad4 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -103,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase): room_id = yield self.create_room() event_id = yield self.send_text_message(room_id, "Hello, World") get = self.get(receipts="-1") - yield self.hs.get_handlers().receipts_handler.received_client_receipt( + yield self.hs.get_receipts_handler().received_client_receipt( room_id, "m.read", self.user_id, event_id ) code, body = yield get diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b4a787c436059ef5cc8a50dc391e3bf7e7bf6d16..b6173ab2eee7f56b6b9d0a0c432707bc1ab14c21 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.get_login_tuple_for_user_id = Mock( - return_value=(token, "kermits_refresh_token") + self.auth_handler.get_access_token_for_user_id = Mock( + return_value=token ) (code, result) = yield self.servlet.on_POST(self.request) @@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, - "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) - self.assertIn("refresh_token", result) @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): @@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_login_tuple_for_user_id = Mock( - return_value=(token, "kermits_refresh_token") + self.auth_handler.get_access_token_for_user_id = Mock( + return_value=token ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, - "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) - self.assertIn("refresh_token", result) self.auth_handler.get_login_tuple_for_user_id( user_id, device_id=device_id, initial_device_display_name=None) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 02a67b733d17aaf4b73b7104b101aecfb49d96f1..9ff1abcd80060ef984d00464297e28bf4ea01efe 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -39,7 +39,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config) + hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) self.as_token = "token1" self.as_url = "some_url" @@ -112,7 +112,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config) + hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) self.db_pool = hs.get_db_pool() self.as_list = [ @@ -443,7 +443,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) - hs = yield setup_test_homeserver(config=config, datastore=Mock()) + hs = yield setup_test_homeserver( + config=config, + datastore=Mock(), + federation_sender=Mock() + ) ApplicationServiceStore(hs) @@ -456,7 +460,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) - hs = yield setup_test_homeserver(config=config, datastore=Mock()) + hs = yield setup_test_homeserver( + config=config, + datastore=Mock(), + federation_sender=Mock() + ) with self.assertRaises(ConfigError) as cm: ApplicationServiceStore(hs) @@ -475,7 +483,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) - hs = yield setup_test_homeserver(config=config, datastore=Mock()) + hs = yield setup_test_homeserver( + config=config, + datastore=Mock(), + federation_sender=Mock() + ) with self.assertRaises(ConfigError) as cm: ApplicationServiceStore(hs) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index f7d74dea8e1f7a3d579cb0873713035bca243498..316ecdb32d4a48425bc43e810e6a205ece9205b2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,9 +17,6 @@ from tests import unittest from twisted.internet import defer -from synapse.api.errors import StoreError -from synapse.util import stringutils - from tests.utils import setup_test_homeserver @@ -80,64 +77,12 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) - @defer.inlineCallbacks - def test_exchange_refresh_token_valid(self): - uid = stringutils.random_string(32) - device_id = stringutils.random_string(16) - generator = TokenGenerator() - last_token = generator.generate(uid) - - self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token, device_id) " - "VALUES(?,?,?)", - (uid, last_token, device_id)) - - (found_user_id, refresh_token, device_id) = \ - yield self.store.exchange_refresh_token(last_token, - generator.generate) - self.assertEqual(uid, found_user_id) - - rows = yield self.db_pool.runQuery( - "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", - (uid, )) - self.assertEqual([(refresh_token, device_id)], rows) - # We issued token 1, then exchanged it for token 2 - expected_refresh_token = u"%s-%d" % (uid, 2,) - self.assertEqual(expected_refresh_token, refresh_token) - - @defer.inlineCallbacks - def test_exchange_refresh_token_none(self): - uid = stringutils.random_string(32) - generator = TokenGenerator() - last_token = generator.generate(uid) - - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(last_token, generator.generate) - - @defer.inlineCallbacks - def test_exchange_refresh_token_invalid(self): - uid = stringutils.random_string(32) - generator = TokenGenerator() - last_token = generator.generate(uid) - wrong_token = "%s-wrong" % (last_token,) - - self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, wrong_token,)) - - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(last_token, generator.generate) - @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - generator = TokenGenerator() - refresh_token = generator.generate(self.user_id) yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], self.device_id) - yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, - self.device_id) # now delete some yield self.store.user_delete_access_tokens( @@ -146,9 +91,6 @@ class RegistrationStoreTestCase(unittest.TestCase): # check they were deleted user = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertIsNone(user, "access token was not deleted by device_id") - with self.assertRaises(StoreError): - yield self.store.exchange_refresh_token(refresh_token, - generator.generate) # check the one not associated with the device was not deleted user = yield self.store.get_user_by_access_token(self.tokens[0]) diff --git a/tests/test_preview.py b/tests/test_preview.py index c8d6525a0122496059c9ebcf7cda335cdfd285dd..5bd36c74aacf4ae3f4a6cd5e6dc20553f90e8dd3 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -24,7 +24,7 @@ class PreviewTestCase(unittest.TestCase): def test_long_summarize(self): example_paras = [ - """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in Troms county, Norway. The administrative centre of the municipality is the city of Tromsø. Outside of Norway, Tromso and Tromsö are @@ -32,7 +32,7 @@ class PreviewTestCase(unittest.TestCase): city in the world with a population above 50,000. The most populous town north of it is Alta, Norway, with a population of 14,272 (2013).""", - """Tromsø lies in Northern Norway. The municipality has a population of + u"""Tromsø lies in Northern Norway. The municipality has a population of (2015) 72,066, but with an annual influx of students it has over 75,000 most of the year. It is the largest urban area in Northern Norway and the third largest north of the Arctic Circle (following Murmansk and Norilsk). @@ -46,7 +46,7 @@ class PreviewTestCase(unittest.TestCase): in Europe. The city is warmer than most other places located on the same latitude, due to the warming effect of the Gulf Stream.""", - """The city centre of Tromsø contains the highest number of old wooden + u"""The city centre of Tromsø contains the highest number of old wooden houses in Northern Norway, the oldest house dating from 1789. The Arctic Cathedral, a modern church from 1965, is probably the most famous landmark in Tromsø. The city is a cultural centre for its region, with several @@ -60,90 +60,90 @@ class PreviewTestCase(unittest.TestCase): self.assertEquals( desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway. The administrative centre of the municipality is" - " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - " alternative spellings of the city.Tromsø is considered the northernmost" - " city in the world with a population above 50,000. The most populous town" - " north of it is Alta, Norway, with a population of 14,272 (2013)." + u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + u" Troms county, Norway. The administrative centre of the municipality is" + u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + u" alternative spellings of the city.Tromsø is considered the northernmost" + u" city in the world with a population above 50,000. The most populous town" + u" north of it is Alta, Norway, with a population of 14,272 (2013)." ) desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) self.assertEquals( desc, - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. It is the largest urban area in Northern Norway and the" - " third largest north of the Arctic Circle (following Murmansk and Norilsk)." - " Most of Tromsø, including the city centre, is located on the island of" - " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - " Tromsøya had a population of 36,088. Substantial parts of the…" + u"Tromsø lies in Northern Norway. The municipality has a population of" + u" (2015) 72,066, but with an annual influx of students it has over 75,000" + u" most of the year. It is the largest urban area in Northern Norway and the" + u" third largest north of the Arctic Circle (following Murmansk and Norilsk)." + u" Most of Tromsø, including the city centre, is located on the island of" + u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + u" Tromsøya had a population of 36,088. Substantial parts of the urban…" ) def test_short_summarize(self): example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - - "The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", + u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + u" Troms county, Norway.", + + u"Tromsø lies in Northern Norway. The municipality has a population of" + u" (2015) 72,066, but with an annual influx of students it has over 75,000" + u" most of the year.", + + u"The city centre of Tromsø contains the highest number of old wooden" + u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" + u" Cathedral, a modern church from 1965, is probably the most famous landmark" + u" in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year." + u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + u" Troms county, Norway.\n" + u"\n" + u"Tromsø lies in Northern Norway. The municipality has a population of" + u" (2015) 72,066, but with an annual influx of students it has over 75,000" + u" most of the year." ) def test_small_then_large_summarize(self): example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year." - " The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", + u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + u" Troms county, Norway.", + + u"Tromsø lies in Northern Norway. The municipality has a population of" + u" (2015) 72,066, but with an annual influx of students it has over 75,000" + u" most of the year." + u" The city centre of Tromsø contains the highest number of old wooden" + u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" + u" Cathedral, a modern church from 1965, is probably the most famous landmark" + u" in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. The city centre of Tromsø contains the highest number" - " of old wooden houses in Northern Norway, the oldest house dating from" - " 1789. The Arctic Cathedral, a modern church…" + u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + u" Troms county, Norway.\n" + u"\n" + u"Tromsø lies in Northern Norway. The municipality has a population of" + u" (2015) 72,066, but with an annual influx of students it has over 75,000" + u" most of the year. The city centre of Tromsø contains the highest number" + u" of old wooden houses in Northern Norway, the oldest house dating from" + u" 1789. The Arctic Cathedral, a modern church from…" ) class PreviewUrlTestCase(unittest.TestCase): def test_simple(self): - html = """ + html = u""" <html> <head><title>Foo</title></head> <body> @@ -155,12 +155,12 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEquals(og, { - "og:title": "Foo", - "og:description": "Some text." + u"og:title": u"Foo", + u"og:description": u"Some text." }) def test_comment(self): - html = """ + html = u""" <html> <head><title>Foo</title></head> <body> @@ -173,12 +173,12 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEquals(og, { - "og:title": "Foo", - "og:description": "Some text." + u"og:title": u"Foo", + u"og:description": u"Some text." }) def test_comment2(self): - html = """ + html = u""" <html> <head><title>Foo</title></head> <body> @@ -194,12 +194,12 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEquals(og, { - "og:title": "Foo", - "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text" + u"og:title": u"Foo", + u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text" }) def test_script(self): - html = """ + html = u""" <html> <head><title>Foo</title></head> <body> @@ -212,6 +212,56 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEquals(og, { - "og:title": "Foo", - "og:description": "Some text." + u"og:title": u"Foo", + u"og:description": u"Some text." + }) + + def test_missing_title(self): + html = u""" + <html> + <body> + Some text. + </body> + </html> + """ + + og = decode_and_calc_og(html, "http://example.com/test.html") + + self.assertEquals(og, { + u"og:title": None, + u"og:description": u"Some text." + }) + + def test_h1_as_title(self): + html = u""" + <html> + <meta property="og:description" content="Some text."/> + <body> + <h1>Title</h1> + </body> + </html> + """ + + og = decode_and_calc_og(html, "http://example.com/test.html") + + self.assertEquals(og, { + u"og:title": u"Title", + u"og:description": u"Some text." + }) + + def test_missing_title_and_broken_h1(self): + html = u""" + <html> + <body> + <h1><a href="foo"/></h1> + Some text. + </body> + </html> + """ + + og = decode_and_calc_og(html, "http://example.com/test.html") + + self.assertEquals(og, { + u"og:title": None, + u"og:description": u"Some text." }) diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..9c795d9fdb08896293b7149e8e0a057d912017a4 --- /dev/null +++ b/tests/util/test_limiter.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from twisted.internet import defer + +from synapse.util.async import Limiter + + +class LimiterTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_limiter(self): + limiter = Limiter(3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + self.assertTrue(d4.called) + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + self.assertTrue(d5.called) + + with cm2: + pass + + with (yield d4): + pass + + with (yield d5): + pass + + d6 = limiter.queue(key) + with (yield d6): + pass diff --git a/tests/utils.py b/tests/utils.py index 5929f1c729b7fff8adcf4c4f72fceb8109659bc7..d3d6c8021de58d757ae27b1b45161dfa5516cb17 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -53,6 +53,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.trusted_third_party_id_servers = [] config.room_invite_state_types = [] config.password_providers = [] + config.worker_replication_url = "" + config.worker_app = None config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} @@ -70,6 +72,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, room_list_handler=object(), + tls_server_context_factory=Mock(), **kargs ) hs.setup() @@ -79,6 +82,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): version_string="Synapse/tests", database_engine=create_engine(config.database_config), room_list_handler=object(), + tls_server_context_factory=Mock(), **kargs ) @@ -290,6 +294,10 @@ class MockClock(object): def advance_time_msec(self, ms): self.advance_time(ms / 1000.) + def time_bound_deferred(self, d, *args, **kwargs): + # We don't bother timing things out for now. + return d + class SQLiteMemoryDbPool(ConnectionPool, object): def __init__(self): diff --git a/tox.ini b/tox.ini index 52d93c65e52b53f976a3920ed19bb162d1bb8371..39ad30536055fe7f44c848c87022185d4445f3e8 100644 --- a/tox.ini +++ b/tox.ini @@ -8,8 +8,15 @@ deps = mock python-subunit junitxml + + # needed by some of the tests + lxml + setenv = PYTHONDONTWRITEBYTECODE = no_byte_code + # As of twisted 16.4, trial tries to import the tests as a package, which + # means it needs to be on the pythonpath. + PYTHONPATH = {toxinidir} commands = /bin/sh -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \ {envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"